use std::collections::HashMap;
use std::sync::Arc;
use a2a_protocol_types::artifact::Artifact;
use a2a_protocol_types::events::{TaskArtifactUpdateEvent, TaskStatusUpdateEvent};
use a2a_protocol_types::task::{ContextId, TaskId, TaskState, TaskStatus};
use chrono::Utc;
use super::error::A2aError;
use super::state_machine::can_transition_to;
use super::task_store::{TaskStore, TaskStoreEntry};
/// Drives A2A v1 task lifecycles on top of a [`TaskStore`].
///
/// Every operation reads and/or writes the backing store and returns the
/// protocol event (`TaskStatusUpdateEvent` or `TaskArtifactUpdateEvent`)
/// describing the change it made.
pub struct V1Executor {
    task_store: Arc<dyn TaskStore>,
}

impl V1Executor {
    /// Creates an executor backed by `task_store`.
    pub fn new(task_store: Arc<dyn TaskStore>) -> Self {
        Self { task_store }
    }

    /// Creates and persists a new task in the `Submitted` state.
    ///
    /// The entry starts with no artifacts, history, metadata or push configs,
    /// and both `created_at` and `updated_at` are set to the current time.
    /// The returned value is a copy of what was handed to the store.
    ///
    /// # Errors
    ///
    /// Returns any error reported by [`TaskStore::create_task`].
    pub async fn create_task(
        &self,
        task_id: &str,
        context_id: &str,
    ) -> Result<TaskStoreEntry, A2aError> {
        let now = Utc::now();
        let entry = TaskStoreEntry {
            id: task_id.to_string(),
            context_id: context_id.to_string(),
            status: TaskStatus::with_timestamp(TaskState::Submitted),
            artifacts: Vec::new(),
            history: Vec::new(),
            metadata: HashMap::new(),
            push_configs: Vec::new(),
            created_at: now,
            updated_at: now,
        };
        // The store takes ownership, but the caller also gets the entry back.
        self.task_store.create_task(entry.clone()).await?;
        Ok(entry)
    }

    /// Moves a task to `new_state` and returns the resulting status event.
    ///
    /// The transition is validated with [`can_transition_to`] against the
    /// task's currently stored state before anything is written. When
    /// `message` is `Some`, it is attached to the new status as a single-part
    /// agent text message.
    ///
    /// The event's `metadata` mirrors the task's stored metadata, or is
    /// `None` when the task has none. `context_id` is copied into the event
    /// as given; it is not compared with the stored task's context.
    ///
    /// # Errors
    ///
    /// Returns an error if the task cannot be loaded, if the transition is
    /// rejected by [`can_transition_to`], or if persisting the status fails.
    pub async fn transition_state(
        &self,
        task_id: &str,
        context_id: &str,
        new_state: TaskState,
        message: Option<String>,
    ) -> Result<TaskStatusUpdateEvent, A2aError> {
        // NOTE(review): load, validation and write are separate store calls,
        // so two concurrent transitions of the same task may both validate
        // against the same old state unless the store serializes them — confirm.
        let current = self.task_store.get_task(task_id).await?;
        can_transition_to(current.status.state, new_state)?;

        let mut status = TaskStatus::with_timestamp(new_state);
        if let Some(text) = message {
            status.message = Some(agent_status_message(task_id, text));
        }
        // The store and the returned event each need their own copy.
        self.task_store.update_status(task_id, status.clone()).await?;

        Ok(TaskStatusUpdateEvent {
            task_id: TaskId::new(task_id),
            context_id: ContextId::new(context_id),
            status,
            metadata: metadata_to_event_value(current.metadata),
        })
    }

    /// Records `artifact` on a task via [`TaskStore::add_artifact`] and
    /// returns the matching artifact event.
    ///
    /// The task must exist, but its lifecycle state is not checked. The event
    /// carries the task's stored metadata (as read before the artifact is
    /// added, or `None` when empty) and leaves `append`/`last_chunk` unset.
    ///
    /// # Errors
    ///
    /// Returns an error if the task cannot be loaded or the artifact cannot
    /// be persisted.
    pub async fn record_artifact(
        &self,
        task_id: &str,
        context_id: &str,
        artifact: Artifact,
    ) -> Result<TaskArtifactUpdateEvent, A2aError> {
        // Loading first also turns an unknown task id into an error before any write.
        let current = self.task_store.get_task(task_id).await?;
        self.task_store.add_artifact(task_id, artifact.clone()).await?;

        Ok(TaskArtifactUpdateEvent {
            task_id: TaskId::new(task_id),
            context_id: ContextId::new(context_id),
            artifact,
            append: None,
            last_chunk: None,
            metadata: metadata_to_event_value(current.metadata),
        })
    }

    /// Transitions a task to `Failed`, attaching `error_message` as the
    /// status message.
    ///
    /// # Errors
    ///
    /// Same as [`V1Executor::transition_state`]; in particular the move to
    /// `Failed` is rejected when [`can_transition_to`] disallows it from the
    /// task's current state.
    pub async fn fail_task(
        &self,
        task_id: &str,
        context_id: &str,
        error_message: &str,
    ) -> Result<TaskStatusUpdateEvent, A2aError> {
        self.transition_state(
            task_id,
            context_id,
            TaskState::Failed,
            Some(error_message.to_string()),
        )
        .await
    }

    /// Returns the store this executor reads from and writes to.
    pub fn task_store(&self) -> &Arc<dyn TaskStore> {
        &self.task_store
    }
}

/// Builds the agent-authored text message attached to a status update.
///
/// NOTE(review): the message id is derived from the task id alone, so every
/// status message of one task gets the same id — confirm this is acceptable.
fn agent_status_message(task_id: &str, text: String) -> a2a_protocol_types::Message {
    a2a_protocol_types::Message {
        id: a2a_protocol_types::MessageId::new(format!("status-msg-{task_id}")),
        role: a2a_protocol_types::MessageRole::Agent,
        parts: vec![a2a_protocol_types::Part::text(text)],
        task_id: None,
        context_id: None,
        reference_task_ids: None,
        extensions: None,
        metadata: None,
    }
}

/// Converts task metadata entries into the JSON object carried by update
/// events, returning `None` when there are no entries.
fn metadata_to_event_value<I>(metadata: I) -> Option<serde_json::Value>
where
    I: IntoIterator<Item = (String, serde_json::Value)>,
{
    let obj: serde_json::Map<String, serde_json::Value> = metadata.into_iter().collect();
    if obj.is_empty() {
        None
    } else {
        Some(serde_json::Value::Object(obj))
    }
}
#[cfg(test)]
mod tests {
    use super::super::task_store::InMemoryTaskStore;
    use super::*;

    /// Builds an executor over a fresh, empty in-memory store.
    fn make_executor() -> V1Executor {
        let store: Arc<dyn TaskStore> = Arc::new(InMemoryTaskStore::new());
        V1Executor::new(store)
    }

    /// Builds an executor whose store already holds `task_id` in `context_id`.
    async fn executor_with_task(task_id: &str, context_id: &str) -> V1Executor {
        let exec = make_executor();
        exec.create_task(task_id, context_id).await.unwrap();
        exec
    }

    /// Builds a single-part text artifact with the fixed id `art-1`.
    fn text_artifact(text: &'static str) -> Artifact {
        let id = a2a_protocol_types::ArtifactId::new("art-1");
        let parts = vec![a2a_protocol_types::Part::text(text)];
        Artifact::new(id, parts)
    }

    /// Checks that a timestamp is present, contains a `T` date/time separator,
    /// and is at least 19 characters long.
    fn assert_valid_timestamp(ts: &Option<String>) {
        let stamp = ts.as_ref().expect("timestamp should be Some");
        assert!(stamp.contains('T'), "timestamp should contain 'T': {stamp}");
        assert!(stamp.len() >= 19, "timestamp should be at least 19 chars: {stamp}");
    }

    #[tokio::test]
    async fn create_task_persists_with_submitted_state() {
        let exec = make_executor();

        let created = exec.create_task("t1", "ctx-1").await.unwrap();
        assert_eq!(created.id, "t1");
        assert_eq!(created.context_id, "ctx-1");
        assert_eq!(created.status.state, TaskState::Submitted);
        assert_valid_timestamp(&created.status.timestamp);

        let persisted = exec.task_store().get_task("t1").await.unwrap();
        assert_eq!(persisted.id, "t1");
        assert_eq!(persisted.status.state, TaskState::Submitted);
        assert_valid_timestamp(&persisted.status.timestamp);
    }

    #[tokio::test]
    async fn transition_state_validates_and_persists() {
        let exec = executor_with_task("t1", "ctx-1").await;

        let ev = exec
            .transition_state("t1", "ctx-1", TaskState::Working, None)
            .await
            .unwrap();
        assert_eq!(ev.task_id, TaskId::new("t1"));
        assert_eq!(ev.context_id, ContextId::new("ctx-1"));
        assert_eq!(ev.status.state, TaskState::Working);
        assert_valid_timestamp(&ev.status.timestamp);

        let persisted = exec.task_store().get_task("t1").await.unwrap();
        assert_eq!(persisted.status.state, TaskState::Working);
        assert_valid_timestamp(&persisted.status.timestamp);
    }

    #[tokio::test]
    async fn transition_state_with_message() {
        let exec = executor_with_task("t1", "ctx-1").await;
        let note = Some("processing started".to_string());

        let ev = exec
            .transition_state("t1", "ctx-1", TaskState::Working, note)
            .await
            .unwrap();
        let attached = ev.status.message.expect("status message should be set");
        assert_eq!(attached.parts.len(), 1);
    }

    #[tokio::test]
    async fn transition_state_rejects_invalid_transition() {
        let exec = executor_with_task("t1", "ctx-1").await;

        let rendered = exec
            .transition_state("t1", "ctx-1", TaskState::Completed, None)
            .await
            .unwrap_err()
            .to_string();
        for needle in ["SUBMITTED", "COMPLETED"] {
            assert!(rendered.contains(needle));
        }
    }

    #[tokio::test]
    async fn transition_state_task_not_found() {
        let exec = make_executor();

        let rendered = exec
            .transition_state("nonexistent", "ctx-1", TaskState::Working, None)
            .await
            .unwrap_err()
            .to_string();
        assert!(rendered.contains("nonexistent"));
    }

    #[tokio::test]
    async fn record_artifact_persists_and_returns_event() {
        let exec = executor_with_task("t1", "ctx-1").await;

        let ev = exec
            .record_artifact("t1", "ctx-1", text_artifact("result"))
            .await
            .unwrap();
        assert_eq!(ev.task_id, TaskId::new("t1"));
        assert_eq!(ev.context_id, ContextId::new("ctx-1"));
        assert_eq!(ev.artifact.id, a2a_protocol_types::ArtifactId::new("art-1"));

        let persisted = exec.task_store().get_task("t1").await.unwrap();
        assert_eq!(persisted.artifacts.len(), 1);
    }

    #[tokio::test]
    async fn record_artifact_task_not_found() {
        let exec = make_executor();

        let rendered = exec
            .record_artifact("nonexistent", "ctx-1", text_artifact("result"))
            .await
            .unwrap_err()
            .to_string();
        assert!(rendered.contains("nonexistent"));
    }

    #[tokio::test]
    async fn fail_task_transitions_to_failed() {
        let exec = executor_with_task("t1", "ctx-1").await;
        exec.transition_state("t1", "ctx-1", TaskState::Working, None)
            .await
            .unwrap();

        let ev = exec.fail_task("t1", "ctx-1", "something went wrong").await.unwrap();
        assert_eq!(ev.status.state, TaskState::Failed);
        assert!(ev.status.message.is_some());

        let persisted = exec.task_store().get_task("t1").await.unwrap();
        assert_eq!(persisted.status.state, TaskState::Failed);
    }

    #[tokio::test]
    async fn fail_task_from_terminal_state_is_rejected() {
        let exec = executor_with_task("t1", "ctx-1").await;
        for step in [TaskState::Working, TaskState::Completed] {
            exec.transition_state("t1", "ctx-1", step, None).await.unwrap();
        }

        let rendered = exec
            .fail_task("t1", "ctx-1", "late error")
            .await
            .unwrap_err()
            .to_string();
        for needle in ["COMPLETED", "FAILED"] {
            assert!(rendered.contains(needle));
        }
    }

    #[tokio::test]
    async fn full_lifecycle_submitted_working_completed() {
        let exec = make_executor();

        let created = exec.create_task("t1", "ctx-1").await.unwrap();
        assert_eq!(created.status.state, TaskState::Submitted);
        assert_valid_timestamp(&created.status.timestamp);

        let working = exec
            .transition_state("t1", "ctx-1", TaskState::Working, None)
            .await
            .unwrap();
        assert_eq!(working.status.state, TaskState::Working);
        assert_valid_timestamp(&working.status.timestamp);

        let produced = exec
            .record_artifact("t1", "ctx-1", text_artifact("output"))
            .await
            .unwrap();
        assert_eq!(produced.artifact.id, a2a_protocol_types::ArtifactId::new("art-1"));

        let done = exec
            .transition_state("t1", "ctx-1", TaskState::Completed, None)
            .await
            .unwrap();
        assert_eq!(done.status.state, TaskState::Completed);
        assert_valid_timestamp(&done.status.timestamp);

        let persisted = exec.task_store().get_task("t1").await.unwrap();
        assert_eq!(persisted.status.state, TaskState::Completed);
        assert_eq!(persisted.artifacts.len(), 1);
        assert_valid_timestamp(&persisted.status.timestamp);
    }

    #[tokio::test]
    async fn metadata_included_in_events() {
        let store = Arc::new(InMemoryTaskStore::new());
        let exec = V1Executor::new(Arc::clone(&store));

        let mut meta_in = HashMap::new();
        meta_in.insert("key".to_string(), serde_json::json!("value"));
        let stamp = Utc::now();
        let seeded = TaskStoreEntry {
            id: "t1".to_string(),
            context_id: "ctx-1".to_string(),
            status: TaskStatus::new(TaskState::Submitted),
            artifacts: Vec::new(),
            history: Vec::new(),
            metadata: meta_in,
            push_configs: Vec::new(),
            created_at: stamp,
            updated_at: stamp,
        };
        store.create_task(seeded).await.unwrap();

        let ev = exec
            .transition_state("t1", "ctx-1", TaskState::Working, None)
            .await
            .unwrap();
        let meta_out = ev.metadata.expect("event metadata should be set");
        assert_eq!(meta_out["key"], "value");
    }

    #[tokio::test]
    async fn context_id_included_in_status_event() {
        let exec = executor_with_task("t1", "ctx-abc").await;

        let ev = exec
            .transition_state("t1", "ctx-abc", TaskState::Working, None)
            .await
            .unwrap();
        assert_eq!(ev.context_id, ContextId::new("ctx-abc"));
    }

    #[tokio::test]
    async fn context_id_included_in_artifact_event() {
        let exec = executor_with_task("t1", "ctx-abc").await;

        let ev = exec
            .record_artifact("t1", "ctx-abc", text_artifact("data"))
            .await
            .unwrap();
        assert_eq!(ev.context_id, ContextId::new("ctx-abc"));
    }

    #[tokio::test]
    async fn task_store_accessor() {
        let store = Arc::new(InMemoryTaskStore::new());
        let exec = V1Executor::new(Arc::clone(&store));
        exec.create_task("t1", "ctx-1").await.unwrap();

        let fetched = exec.task_store().get_task("t1").await.unwrap();
        assert_eq!(fetched.id, "t1");
    }
}