a2a_protocol_client/methods/
send_message.rs1use a2a_protocol_types::params::SendMessageConfiguration;
9use a2a_protocol_types::{MessageSendParams, SendMessageResponse};
10
11use crate::client::A2aClient;
12use crate::config::ClientConfig;
13use crate::error::{ClientError, ClientResult};
14use crate::interceptor::{ClientRequest, ClientResponse};
15use crate::streaming::EventStream;
16
17fn apply_client_config(params: &mut MessageSendParams, config: &ClientConfig) {
22 let cfg = params
23 .configuration
24 .get_or_insert_with(SendMessageConfiguration::default);
25
26 if cfg.return_immediately.is_none() && config.return_immediately {
28 cfg.return_immediately = Some(true);
29 }
30 if cfg.history_length.is_none() {
31 if let Some(hl) = config.history_length {
32 cfg.history_length = Some(hl);
33 }
34 }
35 if cfg.accepted_output_modes.is_empty() && !config.accepted_output_modes.is_empty() {
36 cfg.accepted_output_modes
37 .clone_from(&config.accepted_output_modes);
38 }
39 if params.tenant.is_none() {
41 params.tenant.clone_from(&config.tenant);
42 }
43}
44
45impl A2aClient {
46 pub async fn send_message(
62 &self,
63 mut params: MessageSendParams,
64 ) -> ClientResult<SendMessageResponse> {
65 const METHOD: &str = "SendMessage";
66
67 apply_client_config(&mut params, &self.config);
68
69 let params_value = serde_json::to_value(¶ms).map_err(ClientError::Serialization)?;
70
71 let mut req = ClientRequest::new(METHOD, params_value);
72 self.interceptors.run_before(&mut req).await?;
73
74 let result = self
75 .transport
76 .send_request(METHOD, req.params, &req.extra_headers)
77 .await?;
78
79 let resp = ClientResponse {
80 method: METHOD.to_owned(),
81 result,
82 status_code: 200,
83 };
84 self.interceptors.run_after(&resp).await?;
85
86 serde_json::from_value::<SendMessageResponse>(resp.result)
87 .map_err(ClientError::Serialization)
88 }
89
90 pub async fn stream_message(&self, mut params: MessageSendParams) -> ClientResult<EventStream> {
100 const METHOD: &str = "SendStreamingMessage";
101
102 apply_client_config(&mut params, &self.config);
103
104 let params_value = serde_json::to_value(¶ms).map_err(ClientError::Serialization)?;
105
106 let mut req = ClientRequest::new(METHOD, params_value);
107 self.interceptors.run_before(&mut req).await?;
108
109 let stream = self
110 .transport
111 .send_streaming_request(METHOD, req.params, &req.extra_headers)
112 .await?;
113
114 let resp = ClientResponse {
122 method: METHOD.to_owned(),
123 result: serde_json::Value::Null,
124 status_code: stream.status_code(),
125 };
126 self.interceptors.run_after(&resp).await?;
127
128 Ok(stream)
129 }
130}
131
#[cfg(test)]
mod tests {
    use super::*;
    use a2a_protocol_types::{Message, MessageId, MessageRole, Part};

    /// Minimal request: one user text part, no per-request configuration,
    /// no tenant.
    fn make_params() -> MessageSendParams {
        MessageSendParams {
            tenant: None,
            message: Message {
                id: MessageId::new("msg-1"),
                role: MessageRole::User,
                parts: vec![Part::text("test")],
                task_id: None,
                context_id: None,
                reference_task_ids: None,
                extensions: None,
                metadata: None,
            },
            configuration: None,
            metadata: None,
        }
    }

    #[test]
    fn apply_config_sets_return_immediately_when_absent() {
        let config = ClientConfig {
            return_immediately: true,
            ..ClientConfig::default()
        };

        let mut params = make_params();
        apply_client_config(&mut params, &config);

        let cfg = params.configuration.unwrap();
        assert_eq!(cfg.return_immediately, Some(true));
    }

    #[test]
    fn apply_config_does_not_override_per_request_return_immediately() {
        let config = ClientConfig {
            return_immediately: true,
            ..ClientConfig::default()
        };

        let mut params = make_params();
        params.configuration = Some(SendMessageConfiguration {
            return_immediately: Some(false),
            ..Default::default()
        });
        apply_client_config(&mut params, &config);

        let cfg = params.configuration.unwrap();
        assert_eq!(
            cfg.return_immediately,
            Some(false),
            "per-request value should take precedence"
        );
    }

    #[test]
    fn apply_config_does_not_set_return_immediately_when_config_false() {
        let config = ClientConfig::default();
        let mut params = make_params();
        apply_client_config(&mut params, &config);

        let cfg = params.configuration.unwrap();
        assert_eq!(
            cfg.return_immediately, None,
            "should not set return_immediately when config is false"
        );
    }

    #[test]
    fn apply_config_sets_history_length_when_absent() {
        let config = ClientConfig {
            history_length: Some(10),
            ..ClientConfig::default()
        };

        let mut params = make_params();
        apply_client_config(&mut params, &config);

        let cfg = params.configuration.unwrap();
        assert_eq!(cfg.history_length, Some(10));
    }

    #[test]
    fn apply_config_does_not_override_per_request_history_length() {
        let config = ClientConfig {
            history_length: Some(10),
            ..ClientConfig::default()
        };

        let mut params = make_params();
        params.configuration = Some(SendMessageConfiguration {
            history_length: Some(5),
            ..Default::default()
        });
        apply_client_config(&mut params, &config);

        let cfg = params.configuration.unwrap();
        assert_eq!(cfg.history_length, Some(5));
    }

    #[test]
    fn apply_config_sets_accepted_output_modes_when_empty() {
        let config = ClientConfig {
            accepted_output_modes: vec!["audio/wav".into()],
            ..ClientConfig::default()
        };

        let mut params = make_params();
        params.configuration = Some(SendMessageConfiguration {
            accepted_output_modes: vec![],
            task_push_notification_config: None,
            history_length: None,
            return_immediately: None,
        });
        apply_client_config(&mut params, &config);

        let cfg = params.configuration.unwrap();
        assert_eq!(cfg.accepted_output_modes, vec!["audio/wav"]);
    }

    #[test]
    fn apply_config_does_not_override_per_request_output_modes() {
        let config = ClientConfig {
            accepted_output_modes: vec!["text/plain".into()],
            ..ClientConfig::default()
        };

        let mut params = make_params();
        params.configuration = Some(SendMessageConfiguration {
            accepted_output_modes: vec!["application/json".into()],
            ..Default::default()
        });
        apply_client_config(&mut params, &config);

        let cfg = params.configuration.unwrap();
        assert_eq!(cfg.accepted_output_modes, vec!["application/json"]);
    }

    #[test]
    fn apply_config_no_op_when_config_has_no_overrides() {
        let config = ClientConfig::default();
        let mut params = make_params();
        params.configuration = Some(SendMessageConfiguration::default());
        apply_client_config(&mut params, &config);

        let cfg = params.configuration.unwrap();
        assert_eq!(cfg.return_immediately, None);
        assert_eq!(cfg.history_length, None);
    }

    #[test]
    fn apply_config_sets_tenant_when_absent() {
        let config = ClientConfig {
            tenant: Some("tenant-a".into()),
            ..ClientConfig::default()
        };

        let mut params = make_params();
        apply_client_config(&mut params, &config);

        assert_eq!(params.tenant, config.tenant);
    }

    #[test]
    fn apply_config_does_not_override_per_request_tenant() {
        let config = ClientConfig {
            tenant: Some("tenant-a".into()),
            ..ClientConfig::default()
        };

        let mut params = make_params();
        params.tenant = Some("tenant-b".into());
        let per_request = params.tenant.clone();
        apply_client_config(&mut params, &config);

        assert_eq!(
            params.tenant, per_request,
            "per-request tenant should take precedence"
        );
    }

    #[tokio::test]
    async fn stream_message_applies_config_and_calls_transport() {
        use std::collections::HashMap;
        use std::future::Future;
        use std::pin::Pin;
        use std::sync::{Arc, Mutex};

        use crate::error::{ClientError, ClientResult};
        use crate::streaming::EventStream;
        use crate::transport::Transport;
        use crate::ClientBuilder;

        /// Records the method and params of the streaming call, then fails it.
        struct StreamCapture {
            captured: Arc<Mutex<Option<(String, serde_json::Value)>>>,
        }

        impl Transport for StreamCapture {
            fn send_request<'a>(
                &'a self,
                _method: &'a str,
                _params: serde_json::Value,
                _extra_headers: &'a HashMap<String, String>,
            ) -> Pin<Box<dyn Future<Output = ClientResult<serde_json::Value>> + Send + 'a>>
            {
                Box::pin(async move { Ok(serde_json::Value::Null) })
            }

            fn send_streaming_request<'a>(
                &'a self,
                method: &'a str,
                params: serde_json::Value,
                _extra_headers: &'a HashMap<String, String>,
            ) -> Pin<Box<dyn Future<Output = ClientResult<EventStream>> + Send + 'a>> {
                // Record synchronously so no lock guard lives inside the future.
                *self.captured.lock().expect("capture lock poisoned") =
                    Some((method.to_owned(), params));
                Box::pin(
                    async move { Err(ClientError::Transport("mock: streaming called".into())) },
                )
            }
        }

        let captured = Arc::new(Mutex::new(None));
        let client = ClientBuilder::new("http://localhost:8080")
            .with_custom_transport(StreamCapture {
                captured: Arc::clone(&captured),
            })
            .with_return_immediately(true)
            .build()
            .expect("build");

        let err = client.stream_message(make_params()).await.unwrap_err();
        assert!(
            matches!(err, ClientError::Transport(ref msg) if msg.contains("streaming called")),
            "expected Transport error, got {err:?}"
        );

        let (method, sent) = captured
            .lock()
            .expect("capture lock poisoned")
            .take()
            .expect("transport should have been called");
        assert_eq!(method, "SendStreamingMessage");

        // Decode what actually went over the transport to prove the client
        // config was merged before serialization.
        let sent: MessageSendParams =
            serde_json::from_value(sent).expect("sent params should deserialize");
        let cfg = sent
            .configuration
            .expect("client config should populate configuration");
        assert_eq!(
            cfg.return_immediately,
            Some(true),
            "client-level return_immediately should reach the transport"
        );
    }

    // The mock transport and interceptor are defined inline to keep the test
    // self-contained, which pushes it past clippy's line limit.
    #[allow(clippy::too_many_lines)]
    #[tokio::test]
    async fn stream_message_calls_after_interceptor() {
        use std::collections::HashMap;
        use std::future::Future;
        use std::pin::Pin;
        use std::sync::atomic::{AtomicUsize, Ordering};
        use std::sync::Arc;

        use crate::error::ClientResult;
        use crate::interceptor::{CallInterceptor, ClientRequest, ClientResponse};
        use crate::streaming::EventStream;
        use crate::transport::Transport;
        use crate::ClientBuilder;

        struct StreamingOkTransport;

        impl Transport for StreamingOkTransport {
            fn send_request<'a>(
                &'a self,
                _method: &'a str,
                _params: serde_json::Value,
                _extra_headers: &'a HashMap<String, String>,
            ) -> Pin<Box<dyn Future<Output = ClientResult<serde_json::Value>> + Send + 'a>>
            {
                Box::pin(async move { Ok(serde_json::Value::Null) })
            }

            fn send_streaming_request<'a>(
                &'a self,
                _method: &'a str,
                _params: serde_json::Value,
                _extra_headers: &'a HashMap<String, String>,
            ) -> Pin<Box<dyn Future<Output = ClientResult<EventStream>> + Send + 'a>> {
                Box::pin(async move {
                    let (tx, rx) = tokio::sync::mpsc::channel(8);
                    drop(tx);
                    Ok(EventStream::new(rx))
                })
            }
        }

        struct CountingInterceptor {
            before_count: Arc<AtomicUsize>,
            after_count: Arc<AtomicUsize>,
        }

        impl CallInterceptor for CountingInterceptor {
            async fn before<'a>(&'a self, _req: &'a mut ClientRequest) -> ClientResult<()> {
                self.before_count.fetch_add(1, Ordering::SeqCst);
                Ok(())
            }
            async fn after<'a>(&'a self, _resp: &'a ClientResponse) -> ClientResult<()> {
                self.after_count.fetch_add(1, Ordering::SeqCst);
                Ok(())
            }
        }

        let before = Arc::new(AtomicUsize::new(0));
        let after = Arc::new(AtomicUsize::new(0));
        let interceptor = CountingInterceptor {
            before_count: Arc::clone(&before),
            after_count: Arc::clone(&after),
        };

        let client = ClientBuilder::new("http://localhost:8080")
            .with_custom_transport(StreamingOkTransport)
            .with_interceptor(interceptor)
            .build()
            .expect("build");

        let result = client.stream_message(make_params()).await;
        assert!(result.is_ok(), "stream_message should succeed");
        assert_eq!(before.load(Ordering::SeqCst), 1, "before should be called");
        assert_eq!(
            after.load(Ordering::SeqCst),
            1,
            "after should be called for streaming"
        );
    }

    #[test]
    fn apply_config_does_not_set_modes_when_config_modes_empty() {
        let config = ClientConfig {
            accepted_output_modes: vec![],
            ..ClientConfig::default()
        };

        let mut params = make_params();
        params.configuration = Some(SendMessageConfiguration {
            accepted_output_modes: vec![],
            task_push_notification_config: None,
            history_length: None,
            return_immediately: None,
        });
        apply_client_config(&mut params, &config);

        let cfg = params.configuration.unwrap();
        assert!(
            cfg.accepted_output_modes.is_empty(),
            "should not set modes when config modes are empty"
        );
    }
}