use std::pin::Pin;
use std::sync::Arc;
use futures::Stream;
use crate::application::{HasPushNotifier, HasStreaming, HasTaskLifecycle, TaskStatusBroadcast};
use crate::domain::{
A2AError, AgentCard, DeleteTaskPushNotificationConfigParams,
GetTaskPushNotificationConfigParams, ListTaskPushNotificationConfigsParams, ListTasksParams,
ListTasksResult, Message, Task, TaskId, TaskPushNotificationConfig,
};
use crate::port::{
AsyncMessageHandler, AsyncNotificationManager, AsyncNotificationManagerExt, AsyncPushNotifier,
AsyncStreamingHandler, AsyncTaskLifecycle, AsyncTaskQuery, SeqEvent,
};
use crate::services::server::AgentInfoProvider;
/// Pinned, boxed, `Send` stream of task updates, returned by
/// [`TaskService::send_streaming_message`] and [`TaskService::subscribe`].
///
/// Each item is either a [`SeqEvent`] or an [`A2AError`].
pub type UpdateStream = Pin<Box<dyn Stream<Item = Result<SeqEvent, A2AError>> + Send>>;
/// Service exposing task operations: sending messages (plain and streaming),
/// task lookup/listing/cancellation, subscriptions, push-notification config
/// management and the extended agent card.
///
/// Every collaborator is held behind an `Arc`, so cloning a `TaskService` is
/// cheap and clones share the same underlying implementations.
#[derive(Clone)]
pub struct TaskService {
/// Runs `process_message` for `send_message` / `send_streaming_message`.
message_handler: Arc<dyn AsyncMessageHandler>,
/// Task lookup (`get`); also exposed through `HasTaskLifecycle`.
task_lifecycle: Arc<dyn AsyncTaskLifecycle>,
/// Task listing (`list`).
task_query: Arc<dyn AsyncTaskQuery>,
/// Sets, gets, lists and deletes push-notification configs.
notification_manager: Arc<dyn AsyncNotificationManager>,
/// Source of the authenticated extended agent card.
agent_info: Arc<dyn AgentInfoProvider>,
/// Opens per-task update streams; also exposed through `HasStreaming`.
streaming_handler: Arc<dyn AsyncStreamingHandler>,
/// Push notifier exposed through `HasPushNotifier`.
push_notifier: Arc<dyn AsyncPushNotifier>,
}
impl TaskService {
pub fn new(
message_handler: impl AsyncMessageHandler + 'static,
tasks: impl AsyncTaskLifecycle + AsyncTaskQuery + 'static,
notification_manager: impl AsyncNotificationManager + 'static,
agent_info: impl AgentInfoProvider + 'static,
streaming_handler: impl AsyncStreamingHandler + 'static,
push_notifier: impl AsyncPushNotifier + 'static,
) -> Self {
let tasks = Arc::new(tasks);
Self {
message_handler: Arc::new(message_handler),
task_lifecycle: tasks.clone(),
task_query: tasks,
notification_manager: Arc::new(notification_manager),
agent_info: Arc::new(agent_info),
streaming_handler: Arc::new(streaming_handler),
push_notifier: Arc::new(push_notifier),
}
}
pub fn with_handler(
handler: impl AsyncMessageHandler
+ AsyncTaskLifecycle
+ AsyncTaskQuery
+ AsyncNotificationManager
+ 'static,
agent_info: impl AgentInfoProvider + 'static,
streaming_handler: impl AsyncStreamingHandler + 'static,
push_notifier: impl AsyncPushNotifier + 'static,
) -> Self {
let handler = Arc::new(handler);
Self {
message_handler: handler.clone(),
task_lifecycle: handler.clone(),
task_query: handler.clone(),
notification_manager: handler,
agent_info: Arc::new(agent_info),
streaming_handler: Arc::new(streaming_handler),
push_notifier: Arc::new(push_notifier),
}
}
pub fn with_streaming_handler(
mut self,
streaming_handler: impl AsyncStreamingHandler + 'static,
) -> Self {
self.streaming_handler = Arc::new(streaming_handler);
self
}
pub fn with_push_notifier(mut self, push_notifier: impl AsyncPushNotifier + 'static) -> Self {
self.push_notifier = Arc::new(push_notifier);
self
}
pub async fn send_message(
&self,
task_id: &str,
message: &Message,
session_id: Option<&str>,
push_config: Option<TaskPushNotificationConfig>,
history_limit: Option<u32>,
) -> Result<Task, A2AError> {
if let Some(mut push_config) = push_config {
push_config.task_id = task_id.to_string();
self.notification_manager
.set_validated(&push_config)
.await?;
}
let mut task = self
.message_handler
.process_message(task_id, message, session_id)
.await?;
if let Some(limit) = history_limit {
task = task.with_limited_history(Some(limit));
}
Ok(task)
}
pub async fn send_streaming_message(
&self,
task_id: &str,
message: &Message,
session_id: Option<&str>,
push_config: Option<TaskPushNotificationConfig>,
history_limit: Option<u32>,
) -> Result<(Task, UpdateStream), A2AError> {
if let Some(mut push_config) = push_config {
push_config.task_id = task_id.to_string();
self.notification_manager
.set_validated(&push_config)
.await?;
}
let update_stream = self
.streaming_handler
.start_task_streaming(task_id, None)
.await?;
let mut task = self
.message_handler
.process_message(task_id, message, session_id)
.await?;
if let Some(limit) = history_limit {
task = task.with_limited_history(Some(limit));
}
Ok((task, update_stream))
}
pub async fn get(&self, id: &TaskId, history_length: Option<u32>) -> Result<Task, A2AError> {
self.task_lifecycle.get(id, history_length).await
}
pub async fn list(&self, params: &ListTasksParams) -> Result<ListTasksResult, A2AError> {
self.task_query.list(params).await
}
pub async fn cancel(&self, id: &TaskId) -> Result<Task, A2AError> {
self.cancel_and_broadcast(id).await
}
pub async fn subscribe(
&self,
task_id: &str,
from_event_id: Option<u64>,
) -> Result<(Option<Task>, UpdateStream), A2AError> {
let id: TaskId = task_id.parse()?;
let initial_task = match self.task_lifecycle.get(&id, None).await {
Ok(task) => Some(task),
Err(A2AError::TaskNotFound(_)) => None,
Err(e) => return Err(e),
};
let update_stream = self
.streaming_handler
.start_task_streaming(task_id, from_event_id)
.await?;
Ok((initial_task, update_stream))
}
pub async fn set_push_config(
&self,
config: &TaskPushNotificationConfig,
) -> Result<TaskPushNotificationConfig, A2AError> {
self.notification_manager.set_validated(config).await
}
pub async fn get_push_config(
&self,
params: &GetTaskPushNotificationConfigParams,
) -> Result<TaskPushNotificationConfig, A2AError> {
self.notification_manager.get_config(params).await
}
pub async fn list_push_configs(
&self,
params: &ListTaskPushNotificationConfigsParams,
) -> Result<Vec<TaskPushNotificationConfig>, A2AError> {
self.notification_manager.list_configs(params).await
}
pub async fn delete_push_config(
&self,
params: &DeleteTaskPushNotificationConfigParams,
) -> Result<(), A2AError> {
self.notification_manager.delete_config(params).await
}
pub async fn extended_agent_card(&self) -> Result<AgentCard, A2AError> {
self.agent_info.get_authenticated_extended_card().await
}
}
impl HasTaskLifecycle for TaskService {
fn lifecycle(&self) -> &dyn AsyncTaskLifecycle {
self.task_lifecycle.as_ref()
}
}
impl HasStreaming for TaskService {
fn streaming(&self) -> &dyn AsyncStreamingHandler {
self.streaming_handler.as_ref()
}
}
impl HasPushNotifier for TaskService {
fn push_notifier(&self) -> &dyn AsyncPushNotifier {
self.push_notifier.as_ref()
}
}