use a2a_protocol_types::{
CancelTaskParams, ListTasksParams, Task, TaskIdParams, TaskListResponse, TaskQueryParams,
};
use crate::client::A2aClient;
use crate::error::{ClientError, ClientResult};
use crate::interceptor::{ClientRequest, ClientResponse};
use crate::streaming::EventStream;
impl A2aClient {
pub async fn get_task(&self, params: TaskQueryParams) -> ClientResult<Task> {
const METHOD: &str = "GetTask";
let params_value = serde_json::to_value(¶ms).map_err(ClientError::Serialization)?;
let mut req = ClientRequest::new(METHOD, params_value);
self.interceptors.run_before(&mut req).await?;
let result = self
.transport
.send_request(METHOD, req.params, &req.extra_headers)
.await?;
let resp = ClientResponse {
method: METHOD.to_owned(),
result,
status_code: 200,
};
self.interceptors.run_after(&resp).await?;
serde_json::from_value::<Task>(resp.result).map_err(ClientError::Serialization)
}
pub async fn list_tasks(&self, params: ListTasksParams) -> ClientResult<TaskListResponse> {
const METHOD: &str = "ListTasks";
let params_value = serde_json::to_value(¶ms).map_err(ClientError::Serialization)?;
let mut req = ClientRequest::new(METHOD, params_value);
self.interceptors.run_before(&mut req).await?;
let result = self
.transport
.send_request(METHOD, req.params, &req.extra_headers)
.await?;
let resp = ClientResponse {
method: METHOD.to_owned(),
result,
status_code: 200,
};
self.interceptors.run_after(&resp).await?;
serde_json::from_value::<TaskListResponse>(resp.result).map_err(ClientError::Serialization)
}
pub async fn cancel_task(&self, id: impl Into<String>) -> ClientResult<Task> {
const METHOD: &str = "CancelTask";
let params = CancelTaskParams {
tenant: None,
id: id.into(),
metadata: None,
};
let params_value = serde_json::to_value(¶ms).map_err(ClientError::Serialization)?;
let mut req = ClientRequest::new(METHOD, params_value);
self.interceptors.run_before(&mut req).await?;
let result = self
.transport
.send_request(METHOD, req.params, &req.extra_headers)
.await?;
let resp = ClientResponse {
method: METHOD.to_owned(),
result,
status_code: 200,
};
self.interceptors.run_after(&resp).await?;
serde_json::from_value::<Task>(resp.result).map_err(ClientError::Serialization)
}
pub async fn subscribe_to_task(&self, id: impl Into<String>) -> ClientResult<EventStream> {
const METHOD: &str = "SubscribeToTask";
let params = TaskIdParams {
tenant: None,
id: id.into(),
};
let params_value = serde_json::to_value(¶ms).map_err(ClientError::Serialization)?;
let mut req = ClientRequest::new(METHOD, params_value);
self.interceptors.run_before(&mut req).await?;
let stream = self
.transport
.send_streaming_request(METHOD, req.params, &req.extra_headers)
.await?;
let resp = ClientResponse {
method: METHOD.to_owned(),
result: serde_json::Value::Null,
status_code: 200,
};
self.interceptors.run_after(&resp).await?;
Ok(stream)
}
}
#[cfg(test)]
mod tests {
    use std::collections::HashMap;
    use std::future::Future;
    use std::pin::Pin;

    use a2a_protocol_types::{ListTasksParams, TaskQueryParams};

    use crate::error::{ClientError, ClientResult};
    use crate::streaming::EventStream;
    use crate::transport::Transport;
    use crate::ClientBuilder;

    /// Boxed, sendable future shape returned by the `Transport` methods
    /// implemented in this module.
    type BoxedReply<'a, T> = Pin<Box<dyn Future<Output = ClientResult<T>> + Send + 'a>>;

    /// Transport double that answers every unary request with a fixed JSON value
    /// and fails every streaming request.
    struct MockTransport {
        response: serde_json::Value,
    }

    impl MockTransport {
        /// Creates a mock whose unary requests always resolve to `response`.
        fn new(response: serde_json::Value) -> Self {
            Self { response }
        }
    }

    impl Transport for MockTransport {
        fn send_request<'a>(
            &'a self,
            _method: &'a str,
            _params: serde_json::Value,
            _extra_headers: &'a HashMap<String, String>,
        ) -> BoxedReply<'a, serde_json::Value> {
            // Clone up front so the future owns its payload.
            let canned = self.response.clone();
            Box::pin(async move { Ok(canned) })
        }

        fn send_streaming_request<'a>(
            &'a self,
            _method: &'a str,
            _params: serde_json::Value,
            _extra_headers: &'a HashMap<String, String>,
        ) -> BoxedReply<'a, EventStream> {
            Box::pin(async move {
                let failure = ClientError::Transport("mock: streaming not supported".into());
                Err(failure)
            })
        }
    }

    /// Builds a client wired to the given transport.
    fn make_client(transport: impl Transport) -> crate::A2aClient {
        let builder = ClientBuilder::new("http://localhost:8080").with_custom_transport(transport);
        builder.build().expect("build client")
    }

    /// JSON for a single completed task with id `task-1`.
    fn task_json() -> serde_json::Value {
        serde_json::json!({
            "id": "task-1",
            "contextId": "ctx-1",
            "status": { "state": "TASK_STATE_COMPLETED" }
        })
    }

    #[tokio::test]
    async fn get_task_success() {
        let client = make_client(MockTransport::new(task_json()));
        let query = TaskQueryParams {
            tenant: None,
            id: "task-1".into(),
            history_length: None,
        };

        let fetched = client.get_task(query).await.unwrap();

        assert_eq!(fetched.id.as_ref(), "task-1");
    }

    #[tokio::test]
    async fn list_tasks_success() {
        let canned = serde_json::json!({
            "tasks": [
                {
                    "id": "task-1",
                    "contextId": "ctx-1",
                    "status": { "state": "TASK_STATE_COMPLETED" }
                },
                {
                    "id": "task-2",
                    "contextId": "ctx-2",
                    "status": { "state": "TASK_STATE_WORKING" }
                }
            ]
        });
        let client = make_client(MockTransport::new(canned));

        let listing = client
            .list_tasks(ListTasksParams::default())
            .await
            .unwrap();

        assert_eq!(listing.tasks.len(), 2);
        assert_eq!(listing.tasks[0].id.as_ref(), "task-1");
    }

    #[tokio::test]
    async fn cancel_task_success() {
        let client = make_client(MockTransport::new(task_json()));

        let cancelled = client.cancel_task("task-1").await.unwrap();

        assert_eq!(cancelled.id.as_ref(), "task-1");
    }

    #[tokio::test]
    async fn subscribe_to_task_calls_after_interceptor() {
        use std::sync::atomic::{AtomicUsize, Ordering};
        use std::sync::Arc;

        use crate::interceptor::{CallInterceptor, ClientRequest, ClientResponse};

        /// Transport whose streaming request succeeds with a stream whose
        /// sending half has already been dropped.
        struct StreamingOkTransport;

        impl Transport for StreamingOkTransport {
            fn send_request<'a>(
                &'a self,
                _method: &'a str,
                _params: serde_json::Value,
                _extra_headers: &'a HashMap<String, String>,
            ) -> BoxedReply<'a, serde_json::Value> {
                Box::pin(async move { Ok(serde_json::Value::Null) })
            }

            fn send_streaming_request<'a>(
                &'a self,
                _method: &'a str,
                _params: serde_json::Value,
                _extra_headers: &'a HashMap<String, String>,
            ) -> BoxedReply<'a, EventStream> {
                Box::pin(async move {
                    let (sender, receiver) = tokio::sync::mpsc::channel(8);
                    // Drop the sender before handing the receiver to the stream.
                    drop(sender);
                    Ok(EventStream::new(receiver))
                })
            }
        }

        /// Interceptor that counts how often each hook is invoked.
        struct CountingInterceptor {
            before_calls: Arc<AtomicUsize>,
            after_calls: Arc<AtomicUsize>,
        }

        impl CallInterceptor for CountingInterceptor {
            async fn before<'a>(&'a self, _req: &'a mut ClientRequest) -> ClientResult<()> {
                self.before_calls.fetch_add(1, Ordering::SeqCst);
                Ok(())
            }

            async fn after<'a>(&'a self, _resp: &'a ClientResponse) -> ClientResult<()> {
                self.after_calls.fetch_add(1, Ordering::SeqCst);
                Ok(())
            }
        }

        let before = Arc::new(AtomicUsize::new(0));
        let after = Arc::new(AtomicUsize::new(0));

        let client = ClientBuilder::new("http://localhost:8080")
            .with_custom_transport(StreamingOkTransport)
            .with_interceptor(CountingInterceptor {
                before_calls: Arc::clone(&before),
                after_calls: Arc::clone(&after),
            })
            .build()
            .expect("build");

        let outcome = client.subscribe_to_task("task-1").await;

        assert!(outcome.is_ok(), "subscribe should succeed");
        assert_eq!(before.load(Ordering::SeqCst), 1, "before should be called");
        assert_eq!(
            after.load(Ordering::SeqCst),
            1,
            "after should be called for subscribe streaming"
        );
    }

    #[tokio::test]
    async fn subscribe_to_task_returns_transport_error() {
        let client = make_client(MockTransport::new(serde_json::Value::Null));

        let err = client.subscribe_to_task("task-1").await.unwrap_err();

        let is_streaming_failure = match &err {
            ClientError::Transport(msg) => msg.contains("streaming not supported"),
            _ => false,
        };
        assert!(
            is_streaming_failure,
            "expected Transport error, got {err:?}"
        );
    }
}