use a2a_protocol_types::params::SendMessageConfiguration;
use a2a_protocol_types::{MessageSendParams, SendMessageResponse};
use crate::client::A2aClient;
use crate::config::ClientConfig;
use crate::error::{ClientError, ClientResult};
use crate::interceptor::{ClientRequest, ClientResponse};
use crate::streaming::EventStream;
/// Merges client-wide defaults from `config` into a single request's `params`.
///
/// After this call `params.configuration` is always `Some`: a default
/// [`SendMessageConfiguration`] is inserted when the caller supplied none.
/// Settings already present on the request are never overwritten. Only unset
/// values are filled in:
///
/// * `tenant` — copied from `config` when the request has none.
/// * `return_immediately` — set to `Some(true)` when unset and
///   `config.return_immediately` is `true`.
/// * `history_length` — copied from `config` when unset.
/// * `accepted_output_modes` — copied from `config` when the request's list is
///   empty and the config's list is not.
fn apply_client_config(params: &mut MessageSendParams, config: &ClientConfig) {
    // The per-request tenant wins; otherwise fall back to the client-wide one.
    if params.tenant.is_none() {
        params.tenant = config.tenant.clone();
    }

    let request = params
        .configuration
        .get_or_insert_with(SendMessageConfiguration::default);

    // Only an enabled client-wide flag is propagated; `false` leaves the field unset.
    let inherit_return_immediately =
        config.return_immediately && request.return_immediately.is_none();
    if inherit_return_immediately {
        request.return_immediately = Some(true);
    }

    // `Option::or` keeps an explicit per-request value and otherwise takes the default.
    request.history_length = request.history_length.or(config.history_length);

    let inherit_modes =
        request.accepted_output_modes.is_empty() && !config.accepted_output_modes.is_empty();
    if inherit_modes {
        request
            .accepted_output_modes
            .clone_from(&config.accepted_output_modes);
    }
}
impl A2aClient {
    /// Sends a message and waits for the agent's non-streaming response.
    ///
    /// Client-wide defaults from the client's config are merged into `params`
    /// first; per-request values take precedence. The serialized request then
    /// passes through the `before` interceptors, the transport, and the `after`
    /// interceptors. Finally, the result is decoded into a
    /// [`SendMessageResponse`].
    ///
    /// # Errors
    ///
    /// Returns [`ClientError::Serialization`] if `params` cannot be serialized
    /// or if the transport result cannot be deserialized into a
    /// [`SendMessageResponse`]. Errors from interceptors and the transport are
    /// propagated unchanged.
    pub async fn send_message(
        &self,
        mut params: MessageSendParams,
    ) -> ClientResult<SendMessageResponse> {
        const METHOD: &str = "SendMessage";
        apply_client_config(&mut params, &self.config);
        let params_value = serde_json::to_value(&params).map_err(ClientError::Serialization)?;
        let mut req = ClientRequest::new(METHOD, params_value);
        self.interceptors.run_before(&mut req).await?;
        let result = self
            .transport
            .send_request(METHOD, req.params, &req.extra_headers)
            .await?;
        // The transport hands back only the decoded result value, so a
        // successful call is reported to `after` interceptors with a fixed 200.
        let resp = ClientResponse {
            method: METHOD.to_owned(),
            result,
            status_code: 200,
        };
        self.interceptors.run_after(&resp).await?;
        serde_json::from_value::<SendMessageResponse>(resp.result)
            .map_err(ClientError::Serialization)
    }

    /// Sends a message and returns a stream of the agent's events.
    ///
    /// Client-wide defaults are merged into `params` exactly as in
    /// [`A2aClient::send_message`]. `after` interceptors run once the stream
    /// has been opened, before any event is consumed. They see a `Null` result,
    /// because events arrive through the returned stream, together with the
    /// status code reported by the stream.
    ///
    /// # Errors
    ///
    /// Returns [`ClientError::Serialization`] if `params` cannot be serialized.
    /// Errors from interceptors and the transport are propagated unchanged.
    pub async fn stream_message(&self, mut params: MessageSendParams) -> ClientResult<EventStream> {
        const METHOD: &str = "SendStreamingMessage";
        apply_client_config(&mut params, &self.config);
        let params_value = serde_json::to_value(&params).map_err(ClientError::Serialization)?;
        let mut req = ClientRequest::new(METHOD, params_value);
        self.interceptors.run_before(&mut req).await?;
        let stream = self
            .transport
            .send_streaming_request(METHOD, req.params, &req.extra_headers)
            .await?;
        let resp = ClientResponse {
            method: METHOD.to_owned(),
            // No single result body exists for a stream; events come via `stream`.
            result: serde_json::Value::Null,
            status_code: stream.status_code(),
        };
        self.interceptors.run_after(&resp).await?;
        Ok(stream)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use a2a_protocol_types::{Message, MessageId, MessageRole, Part};

    /// Builds a minimal single-part user message request with no
    /// configuration, tenant, or metadata.
    fn make_params() -> MessageSendParams {
        MessageSendParams {
            tenant: None,
            message: Message {
                id: MessageId::new("msg-1"),
                role: MessageRole::User,
                parts: vec![Part::text("test")],
                task_id: None,
                context_id: None,
                reference_task_ids: None,
                extensions: None,
                metadata: None,
            },
            configuration: None,
            metadata: None,
        }
    }

    #[test]
    fn apply_config_sets_return_immediately_when_absent() {
        let config = ClientConfig {
            return_immediately: true,
            ..ClientConfig::default()
        };
        let mut params = make_params();
        apply_client_config(&mut params, &config);
        let cfg = params.configuration.unwrap();
        assert_eq!(cfg.return_immediately, Some(true));
    }

    #[test]
    fn apply_config_does_not_override_per_request_return_immediately() {
        let config = ClientConfig {
            return_immediately: true,
            ..ClientConfig::default()
        };
        let mut params = make_params();
        params.configuration = Some(SendMessageConfiguration {
            return_immediately: Some(false),
            ..Default::default()
        });
        apply_client_config(&mut params, &config);
        let cfg = params.configuration.unwrap();
        assert_eq!(
            cfg.return_immediately,
            Some(false),
            "per-request value should take precedence"
        );
    }

    #[test]
    fn apply_config_does_not_set_return_immediately_when_config_false() {
        let config = ClientConfig::default();
        let mut params = make_params();
        apply_client_config(&mut params, &config);
        let cfg = params.configuration.unwrap();
        assert_eq!(
            cfg.return_immediately, None,
            "should not set return_immediately when config is false"
        );
    }

    #[test]
    fn apply_config_sets_history_length_when_absent() {
        let config = ClientConfig {
            history_length: Some(10),
            ..ClientConfig::default()
        };
        let mut params = make_params();
        apply_client_config(&mut params, &config);
        let cfg = params.configuration.unwrap();
        assert_eq!(cfg.history_length, Some(10));
    }

    #[test]
    fn apply_config_does_not_override_per_request_history_length() {
        let config = ClientConfig {
            history_length: Some(10),
            ..ClientConfig::default()
        };
        let mut params = make_params();
        params.configuration = Some(SendMessageConfiguration {
            history_length: Some(5),
            ..Default::default()
        });
        apply_client_config(&mut params, &config);
        let cfg = params.configuration.unwrap();
        assert_eq!(cfg.history_length, Some(5));
    }

    #[test]
    fn apply_config_sets_accepted_output_modes_when_empty() {
        let config = ClientConfig {
            accepted_output_modes: vec!["audio/wav".into()],
            ..ClientConfig::default()
        };
        let mut params = make_params();
        params.configuration = Some(SendMessageConfiguration {
            accepted_output_modes: vec![],
            task_push_notification_config: None,
            history_length: None,
            return_immediately: None,
        });
        apply_client_config(&mut params, &config);
        let cfg = params.configuration.unwrap();
        assert_eq!(cfg.accepted_output_modes, vec!["audio/wav"]);
    }

    #[test]
    fn apply_config_does_not_override_per_request_output_modes() {
        let config = ClientConfig {
            accepted_output_modes: vec!["text/plain".into()],
            ..ClientConfig::default()
        };
        let mut params = make_params();
        params.configuration = Some(SendMessageConfiguration {
            accepted_output_modes: vec!["application/json".into()],
            ..Default::default()
        });
        apply_client_config(&mut params, &config);
        let cfg = params.configuration.unwrap();
        assert_eq!(cfg.accepted_output_modes, vec!["application/json"]);
    }

    #[test]
    fn apply_config_no_op_when_config_has_no_overrides() {
        let config = ClientConfig::default();
        let mut params = make_params();
        params.configuration = Some(SendMessageConfiguration::default());
        apply_client_config(&mut params, &config);
        let cfg = params.configuration.unwrap();
        assert_eq!(cfg.return_immediately, None);
        assert_eq!(cfg.history_length, None);
    }

    #[test]
    fn apply_config_sets_tenant_when_absent() {
        let config = ClientConfig {
            tenant: Some("tenant-a".into()),
            ..ClientConfig::default()
        };
        let mut params = make_params();
        apply_client_config(&mut params, &config);
        assert_eq!(params.tenant, config.tenant);
    }

    #[test]
    fn apply_config_does_not_override_per_request_tenant() {
        let config = ClientConfig {
            tenant: Some("tenant-a".into()),
            ..ClientConfig::default()
        };
        let mut params = make_params();
        params.tenant = Some("tenant-b".into());
        apply_client_config(&mut params, &config);
        assert_eq!(
            params.tenant,
            Some("tenant-b".into()),
            "per-request tenant should take precedence"
        );
    }

    #[tokio::test]
    async fn stream_message_applies_config_and_calls_transport() {
        use std::collections::HashMap;
        use std::future::Future;
        use std::pin::Pin;
        use std::sync::{Arc, Mutex};

        use crate::transport::Transport;
        use crate::ClientBuilder;

        /// Records the params handed to the streaming call, then fails so the
        /// test does not need a real event stream.
        struct StreamCapture {
            captured: Arc<Mutex<Option<serde_json::Value>>>,
        }

        impl Transport for StreamCapture {
            fn send_request<'a>(
                &'a self,
                _method: &'a str,
                _params: serde_json::Value,
                _extra_headers: &'a HashMap<String, String>,
            ) -> Pin<Box<dyn Future<Output = ClientResult<serde_json::Value>> + Send + 'a>>
            {
                Box::pin(async move { Ok(serde_json::Value::Null) })
            }

            fn send_streaming_request<'a>(
                &'a self,
                _method: &'a str,
                params: serde_json::Value,
                _extra_headers: &'a HashMap<String, String>,
            ) -> Pin<Box<dyn Future<Output = ClientResult<EventStream>> + Send + 'a>> {
                // Capture synchronously so the lock is never held across an await.
                *self.captured.lock().expect("capture lock poisoned") = Some(params);
                Box::pin(
                    async move { Err(ClientError::Transport("mock: streaming called".into())) },
                )
            }
        }

        let captured = Arc::new(Mutex::new(None));
        let client = ClientBuilder::new("http://localhost:8080")
            .with_custom_transport(StreamCapture {
                captured: Arc::clone(&captured),
            })
            .with_return_immediately(true)
            .build()
            .expect("build");

        let err = client.stream_message(make_params()).await.unwrap_err();
        assert!(
            matches!(err, ClientError::Transport(ref msg) if msg.contains("streaming called")),
            "expected Transport error, got {err:?}"
        );

        let sent = captured
            .lock()
            .expect("capture lock poisoned")
            .take()
            .expect("transport should have received params");
        let cfg = sent
            .get("configuration")
            .expect("configuration should be serialized");
        // Accept either field-naming convention so the test does not depend on serde renames.
        let return_immediately = cfg
            .get("returnImmediately")
            .or_else(|| cfg.get("return_immediately"));
        assert_eq!(
            return_immediately,
            Some(&serde_json::Value::Bool(true)),
            "client-wide return_immediately should be applied to the streamed request, got {sent}"
        );
    }

    // The mock transport and interceptor must be defined inline, which pushes
    // this single scenario past clippy's line limit.
    #[allow(clippy::too_many_lines)]
    #[tokio::test]
    async fn stream_message_calls_after_interceptor() {
        use std::collections::HashMap;
        use std::future::Future;
        use std::pin::Pin;
        use std::sync::atomic::{AtomicUsize, Ordering};
        use std::sync::Arc;

        use crate::interceptor::CallInterceptor;
        use crate::transport::Transport;
        use crate::ClientBuilder;

        /// Opens an already-closed event stream successfully.
        struct StreamingOkTransport;

        impl Transport for StreamingOkTransport {
            fn send_request<'a>(
                &'a self,
                _method: &'a str,
                _params: serde_json::Value,
                _extra_headers: &'a HashMap<String, String>,
            ) -> Pin<Box<dyn Future<Output = ClientResult<serde_json::Value>> + Send + 'a>>
            {
                Box::pin(async move { Ok(serde_json::Value::Null) })
            }

            fn send_streaming_request<'a>(
                &'a self,
                _method: &'a str,
                _params: serde_json::Value,
                _extra_headers: &'a HashMap<String, String>,
            ) -> Pin<Box<dyn Future<Output = ClientResult<EventStream>> + Send + 'a>> {
                Box::pin(async move {
                    let (tx, rx) = tokio::sync::mpsc::channel(8);
                    // Dropping the sender closes the channel; no events are needed here.
                    drop(tx);
                    Ok(EventStream::new(rx))
                })
            }
        }

        /// Counts how many times each interceptor hook runs.
        struct CountingInterceptor {
            before_count: Arc<AtomicUsize>,
            after_count: Arc<AtomicUsize>,
        }

        impl CallInterceptor for CountingInterceptor {
            async fn before<'a>(&'a self, _req: &'a mut ClientRequest) -> ClientResult<()> {
                self.before_count.fetch_add(1, Ordering::SeqCst);
                Ok(())
            }

            async fn after<'a>(&'a self, _resp: &'a ClientResponse) -> ClientResult<()> {
                self.after_count.fetch_add(1, Ordering::SeqCst);
                Ok(())
            }
        }

        let before = Arc::new(AtomicUsize::new(0));
        let after = Arc::new(AtomicUsize::new(0));
        let interceptor = CountingInterceptor {
            before_count: Arc::clone(&before),
            after_count: Arc::clone(&after),
        };
        let client = ClientBuilder::new("http://localhost:8080")
            .with_custom_transport(StreamingOkTransport)
            .with_interceptor(interceptor)
            .build()
            .expect("build");

        let result = client.stream_message(make_params()).await;
        assert!(result.is_ok(), "stream_message should succeed");
        assert_eq!(before.load(Ordering::SeqCst), 1, "before should be called");
        assert_eq!(
            after.load(Ordering::SeqCst),
            1,
            "after should be called for streaming"
        );
    }

    #[test]
    fn apply_config_does_not_set_modes_when_config_modes_empty() {
        let config = ClientConfig {
            accepted_output_modes: vec![],
            ..ClientConfig::default()
        };
        let mut params = make_params();
        params.configuration = Some(SendMessageConfiguration {
            accepted_output_modes: vec![],
            task_push_notification_config: None,
            history_length: None,
            return_immediately: None,
        });
        apply_client_config(&mut params, &config);
        let cfg = params.configuration.unwrap();
        assert!(
            cfg.accepted_output_modes.is_empty(),
            "should not set modes when config modes are empty"
        );
    }
}