use std::collections::HashMap;
use std::time::Instant;
use a2a_protocol_types::params::TaskIdParams;
use a2a_protocol_types::task::TaskId;
use crate::error::{ServerError, ServerResult};
use crate::streaming::InMemoryQueueReader;
use super::super::helpers::build_call_context;
use super::super::RequestHandler;
impl RequestHandler {
pub async fn on_resubscribe(
&self,
params: TaskIdParams,
headers: Option<&HashMap<String, String>>,
) -> ServerResult<InMemoryQueueReader> {
let start = Instant::now();
trace_info!(method = "SubscribeToTask", task_id = %params.id, "handling resubscribe");
self.metrics.on_request("SubscribeToTask");
let tenant = params.tenant.clone().unwrap_or_default();
let result: ServerResult<_> = crate::store::tenant::TenantContext::scope(tenant, async {
let call_ctx = build_call_context("SubscribeToTask", headers);
self.interceptors.run_before(&call_ctx).await?;
let task_id = TaskId::new(¶ms.id);
let _task = self
.task_store
.get(&task_id)
.await?
.ok_or_else(|| ServerError::TaskNotFound(task_id.clone()))?;
let reader = self
.event_queue_manager
.subscribe(&task_id)
.await
.ok_or_else(|| ServerError::Internal("no active event queue for task".into()))?;
self.interceptors.run_after(&call_ctx).await?;
Ok(reader)
})
.await;
let elapsed = start.elapsed();
match &result {
Ok(_) => {
self.metrics.on_response("SubscribeToTask");
self.metrics.on_latency("SubscribeToTask", elapsed);
}
Err(e) => {
self.metrics.on_error("SubscribeToTask", &e.to_string());
self.metrics.on_latency("SubscribeToTask", elapsed);
}
}
result
}
}
#[cfg(test)]
mod tests {
use a2a_protocol_types::params::TaskIdParams;
use crate::agent_executor;
use crate::builder::RequestHandlerBuilder;
use crate::error::ServerError;
struct DummyExecutor;
agent_executor!(DummyExecutor, |_ctx, _queue| async { Ok(()) });
#[tokio::test]
async fn resubscribe_task_not_found_returns_error() {
let handler = RequestHandlerBuilder::new(DummyExecutor).build().unwrap();
let params = TaskIdParams {
tenant: None,
id: "nonexistent-task".to_owned(),
};
let result = handler.on_resubscribe(params, None).await;
assert!(
matches!(result, Err(ServerError::TaskNotFound(_))),
"expected TaskNotFound for missing task, got: {result:?}"
);
}
#[tokio::test]
async fn resubscribe_task_exists_but_no_queue_returns_internal_error() {
use a2a_protocol_types::task::{ContextId, Task, TaskId, TaskState, TaskStatus};
let handler = RequestHandlerBuilder::new(DummyExecutor).build().unwrap();
let task = Task {
id: TaskId::new("t-resub-1"),
context_id: ContextId::new("ctx-1"),
status: TaskStatus::new(TaskState::Completed),
history: None,
artifacts: None,
metadata: None,
};
handler.task_store.save(task).await.unwrap();
let params = TaskIdParams {
tenant: None,
id: "t-resub-1".to_owned(),
};
let result = handler.on_resubscribe(params, None).await;
assert!(
matches!(result, Err(ServerError::Internal(_))),
"expected Internal error when no event queue exists, got: {result:?}"
);
}
#[tokio::test]
async fn resubscribe_success_returns_reader() {
use a2a_protocol_types::message::{Message, MessageId, MessageRole, Part};
use a2a_protocol_types::params::MessageSendParams;
use a2a_protocol_types::task::ContextId;
use crate::handler::SendMessageResult;
let handler = RequestHandlerBuilder::new(DummyExecutor).build().unwrap();
let params = MessageSendParams {
context_id: None,
message: Message {
id: MessageId::new("msg-resub"),
role: MessageRole::User,
parts: vec![Part::text("hello")],
context_id: Some(ContextId::new("ctx-resub")),
task_id: None,
reference_task_ids: None,
extensions: None,
metadata: None,
},
configuration: None,
metadata: None,
tenant: None,
};
let result = handler.on_send_message(params, true, None).await;
assert!(matches!(result, Ok(SendMessageResult::Stream(_))));
let tasks = handler
.task_store
.list(&a2a_protocol_types::params::ListTasksParams::default())
.await
.unwrap();
assert!(!tasks.tasks.is_empty(), "should have at least one task");
let task_id = tasks.tasks[0].id.0.clone();
let sub_params = TaskIdParams {
tenant: None,
id: task_id,
};
let sub_result = handler.on_resubscribe(sub_params, None).await;
match &sub_result {
Ok(_) | Err(ServerError::Internal(_)) => {} Err(e) => panic!("unexpected error: {e:?}"),
}
}
#[tokio::test]
async fn resubscribe_with_tenant() {
let handler = RequestHandlerBuilder::new(DummyExecutor).build().unwrap();
let params = TaskIdParams {
tenant: Some("test-tenant".to_string()),
id: "nonexistent-task".to_owned(),
};
let result = handler.on_resubscribe(params, None).await;
assert!(result.is_err(), "resubscribe for missing task should fail");
}
#[tokio::test]
async fn resubscribe_with_headers() {
let handler = RequestHandlerBuilder::new(DummyExecutor).build().unwrap();
let params = TaskIdParams {
tenant: None,
id: "nonexistent-task".to_owned(),
};
let mut headers = std::collections::HashMap::new();
headers.insert("authorization".to_string(), "Bearer tok".to_string());
let result = handler.on_resubscribe(params, Some(&headers)).await;
assert!(result.is_err());
}
#[tokio::test]
async fn resubscribe_error_path_records_error_metrics() {
use crate::call_context::CallContext;
use crate::interceptor::ServerInterceptor;
use std::future::Future;
use std::pin::Pin;
struct FailInterceptor;
impl ServerInterceptor for FailInterceptor {
fn before<'a>(
&'a self,
_ctx: &'a CallContext,
) -> Pin<Box<dyn Future<Output = a2a_protocol_types::error::A2aResult<()>> + Send + 'a>>
{
Box::pin(async {
Err(a2a_protocol_types::error::A2aError::internal(
"forced failure",
))
})
}
fn after<'a>(
&'a self,
_ctx: &'a CallContext,
) -> Pin<Box<dyn Future<Output = a2a_protocol_types::error::A2aResult<()>> + Send + 'a>>
{
Box::pin(async { Ok(()) })
}
}
let handler = RequestHandlerBuilder::new(DummyExecutor)
.with_interceptor(FailInterceptor)
.build()
.unwrap();
let params = TaskIdParams {
tenant: None,
id: "t-resub-fail".to_owned(),
};
let result = handler.on_resubscribe(params, None).await;
assert!(
result.is_err(),
"resubscribe should fail when interceptor rejects"
);
}
}