a2a_protocol_client/methods/
send_message.rs1use a2a_protocol_types::params::SendMessageConfiguration;
9use a2a_protocol_types::{MessageSendParams, SendMessageResponse};
10
11use crate::client::A2aClient;
12use crate::config::ClientConfig;
13use crate::error::{ClientError, ClientResult};
14use crate::interceptor::{ClientRequest, ClientResponse};
15use crate::streaming::EventStream;
16
17fn apply_client_config(params: &mut MessageSendParams, config: &ClientConfig) {
22 let cfg = params
23 .configuration
24 .get_or_insert_with(SendMessageConfiguration::default);
25
26 if cfg.return_immediately.is_none() && config.return_immediately {
28 cfg.return_immediately = Some(true);
29 }
30 if cfg.history_length.is_none() {
31 if let Some(hl) = config.history_length {
32 cfg.history_length = Some(hl);
33 }
34 }
35 if cfg.accepted_output_modes.is_empty() && !config.accepted_output_modes.is_empty() {
36 cfg.accepted_output_modes
37 .clone_from(&config.accepted_output_modes);
38 }
39}
40
41impl A2aClient {
42 pub async fn send_message(
58 &self,
59 mut params: MessageSendParams,
60 ) -> ClientResult<SendMessageResponse> {
61 const METHOD: &str = "SendMessage";
62
63 apply_client_config(&mut params, &self.config);
64
65 let params_value = serde_json::to_value(¶ms).map_err(ClientError::Serialization)?;
66
67 let mut req = ClientRequest::new(METHOD, params_value);
68 self.interceptors.run_before(&mut req).await?;
69
70 let result = self
71 .transport
72 .send_request(METHOD, req.params, &req.extra_headers)
73 .await?;
74
75 let resp = ClientResponse {
76 method: METHOD.to_owned(),
77 result,
78 status_code: 200,
79 };
80 self.interceptors.run_after(&resp).await?;
81
82 serde_json::from_value::<SendMessageResponse>(resp.result)
83 .map_err(ClientError::Serialization)
84 }
85
86 pub async fn stream_message(&self, mut params: MessageSendParams) -> ClientResult<EventStream> {
96 const METHOD: &str = "SendStreamingMessage";
97
98 apply_client_config(&mut params, &self.config);
99
100 let params_value = serde_json::to_value(¶ms).map_err(ClientError::Serialization)?;
101
102 let mut req = ClientRequest::new(METHOD, params_value);
103 self.interceptors.run_before(&mut req).await?;
104
105 let stream = self
106 .transport
107 .send_streaming_request(METHOD, req.params, &req.extra_headers)
108 .await?;
109
110 let resp = ClientResponse {
118 method: METHOD.to_owned(),
119 result: serde_json::Value::Null,
120 status_code: stream.status_code(),
121 };
122 self.interceptors.run_after(&resp).await?;
123
124 Ok(stream)
125 }
126}
127
#[cfg(test)]
mod tests {
    use super::*;
    use a2a_protocol_types::{Message, MessageId, MessageRole, Part};

    /// Builds a minimal single-part user message request with no per-request
    /// configuration, so every `apply_client_config` default is eligible.
    fn make_params() -> MessageSendParams {
        MessageSendParams {
            tenant: None,
            context_id: None,
            message: Message {
                id: MessageId::new("msg-1"),
                role: MessageRole::User,
                parts: vec![Part::text("test")],
                task_id: None,
                context_id: None,
                reference_task_ids: None,
                extensions: None,
                metadata: None,
            },
            configuration: None,
            metadata: None,
        }
    }

    #[test]
    fn apply_config_sets_return_immediately_when_absent() {
        let config = ClientConfig {
            return_immediately: true,
            ..ClientConfig::default()
        };

        let mut params = make_params();
        apply_client_config(&mut params, &config);

        let cfg = params.configuration.unwrap();
        assert_eq!(cfg.return_immediately, Some(true));
    }

    #[test]
    fn apply_config_does_not_override_per_request_return_immediately() {
        let config = ClientConfig {
            return_immediately: true,
            ..ClientConfig::default()
        };

        let mut params = make_params();
        params.configuration = Some(SendMessageConfiguration {
            return_immediately: Some(false),
            ..Default::default()
        });
        apply_client_config(&mut params, &config);

        let cfg = params.configuration.unwrap();
        assert_eq!(
            cfg.return_immediately,
            Some(false),
            "per-request value should take precedence"
        );
    }

    #[test]
    fn apply_config_does_not_set_return_immediately_when_config_false() {
        let config = ClientConfig::default();
        let mut params = make_params();
        apply_client_config(&mut params, &config);

        // A configuration object is still inserted, but the flag stays unset.
        let cfg = params.configuration.unwrap();
        assert_eq!(
            cfg.return_immediately, None,
            "should not set return_immediately when config is false"
        );
    }

    #[test]
    fn apply_config_sets_history_length_when_absent() {
        let config = ClientConfig {
            history_length: Some(10),
            ..ClientConfig::default()
        };

        let mut params = make_params();
        apply_client_config(&mut params, &config);

        let cfg = params.configuration.unwrap();
        assert_eq!(cfg.history_length, Some(10));
    }

    #[test]
    fn apply_config_does_not_override_per_request_history_length() {
        let config = ClientConfig {
            history_length: Some(10),
            ..ClientConfig::default()
        };

        let mut params = make_params();
        params.configuration = Some(SendMessageConfiguration {
            history_length: Some(5),
            ..Default::default()
        });
        apply_client_config(&mut params, &config);

        let cfg = params.configuration.unwrap();
        assert_eq!(cfg.history_length, Some(5));
    }

    #[test]
    fn apply_config_sets_accepted_output_modes_when_empty() {
        let config = ClientConfig {
            accepted_output_modes: vec!["audio/wav".into()],
            ..ClientConfig::default()
        };

        let mut params = make_params();
        // Every field is spelled out; the request's modes list is deliberately empty.
        params.configuration = Some(SendMessageConfiguration {
            accepted_output_modes: vec![],
            task_push_notification_config: None,
            history_length: None,
            return_immediately: None,
        });
        apply_client_config(&mut params, &config);

        let cfg = params.configuration.unwrap();
        assert_eq!(cfg.accepted_output_modes, vec!["audio/wav"]);
    }

    #[test]
    fn apply_config_does_not_override_per_request_output_modes() {
        let config = ClientConfig {
            accepted_output_modes: vec!["text/plain".into()],
            ..ClientConfig::default()
        };

        let mut params = make_params();
        params.configuration = Some(SendMessageConfiguration {
            accepted_output_modes: vec!["application/json".into()],
            ..Default::default()
        });
        apply_client_config(&mut params, &config);

        let cfg = params.configuration.unwrap();
        assert_eq!(cfg.accepted_output_modes, vec!["application/json"]);
    }

    #[test]
    fn apply_config_no_op_when_config_has_no_overrides() {
        let config = ClientConfig::default();
        let mut params = make_params();
        params.configuration = Some(SendMessageConfiguration::default());
        apply_client_config(&mut params, &config);

        let cfg = params.configuration.unwrap();
        assert_eq!(cfg.return_immediately, None);
        assert_eq!(cfg.history_length, None);
    }

    #[tokio::test]
    async fn stream_message_applies_config_and_calls_transport() {
        use std::collections::HashMap;
        use std::future::Future;
        use std::pin::Pin;
        use std::sync::{Arc, Mutex};

        use crate::error::{ClientError, ClientResult};
        use crate::streaming::EventStream;
        use crate::transport::Transport;
        use crate::ClientBuilder;

        /// Mock transport that records the params of the streaming call, then
        /// fails with a recognisable error.
        struct StreamCapture {
            captured: Arc<Mutex<Option<serde_json::Value>>>,
        }

        impl Transport for StreamCapture {
            fn send_request<'a>(
                &'a self,
                _method: &'a str,
                _params: serde_json::Value,
                _extra_headers: &'a HashMap<String, String>,
            ) -> Pin<Box<dyn Future<Output = ClientResult<serde_json::Value>> + Send + 'a>>
            {
                Box::pin(async move { Ok(serde_json::Value::Null) })
            }

            fn send_streaming_request<'a>(
                &'a self,
                _method: &'a str,
                params: serde_json::Value,
                _extra_headers: &'a HashMap<String, String>,
            ) -> Pin<Box<dyn Future<Output = ClientResult<EventStream>> + Send + 'a>> {
                // Record synchronously so no lock guard lives inside the returned future.
                *self.captured.lock().expect("capture mutex poisoned") = Some(params);
                Box::pin(
                    async move { Err(ClientError::Transport("mock: streaming called".into())) },
                )
            }
        }

        let captured = Arc::new(Mutex::new(None));
        let client = ClientBuilder::new("http://localhost:8080")
            .with_custom_transport(StreamCapture {
                captured: Arc::clone(&captured),
            })
            .with_return_immediately(true)
            .build()
            .expect("build");

        let err = client.stream_message(make_params()).await.unwrap_err();
        assert!(
            matches!(err, ClientError::Transport(ref msg) if msg.contains("streaming called")),
            "expected Transport error, got {err:?}"
        );

        // The transport must receive the params *after* client defaults were
        // merged in, i.e. exactly what `apply_client_config` produces for the
        // client's own config.
        let mut expected = make_params();
        apply_client_config(&mut expected, &client.config);
        let expected = serde_json::to_value(&expected).expect("serialize expected params");
        let sent = captured.lock().expect("capture mutex poisoned").take();
        assert_eq!(
            sent,
            Some(expected),
            "transport should receive config-applied params"
        );
    }

    // The mock transport and interceptor are defined inline, which pushes this
    // test past clippy's function-length limit.
    #[allow(clippy::too_many_lines)]
    #[tokio::test]
    async fn stream_message_calls_after_interceptor() {
        use std::collections::HashMap;
        use std::future::Future;
        use std::pin::Pin;
        use std::sync::atomic::{AtomicUsize, Ordering};
        use std::sync::Arc;

        use crate::error::ClientResult;
        use crate::interceptor::{CallInterceptor, ClientRequest, ClientResponse};
        use crate::streaming::EventStream;
        use crate::transport::Transport;
        use crate::ClientBuilder;

        /// Mock transport whose streaming call succeeds with an already-closed stream.
        struct StreamingOkTransport;

        impl Transport for StreamingOkTransport {
            fn send_request<'a>(
                &'a self,
                _method: &'a str,
                _params: serde_json::Value,
                _extra_headers: &'a HashMap<String, String>,
            ) -> Pin<Box<dyn Future<Output = ClientResult<serde_json::Value>> + Send + 'a>>
            {
                Box::pin(async move { Ok(serde_json::Value::Null) })
            }

            fn send_streaming_request<'a>(
                &'a self,
                _method: &'a str,
                _params: serde_json::Value,
                _extra_headers: &'a HashMap<String, String>,
            ) -> Pin<Box<dyn Future<Output = ClientResult<EventStream>> + Send + 'a>> {
                Box::pin(async move {
                    let (tx, rx) = tokio::sync::mpsc::channel(8);
                    // Dropping the sender closes the channel; the stream has no events.
                    drop(tx);
                    Ok(EventStream::new(rx))
                })
            }
        }

        /// Interceptor that counts how many times each hook is invoked.
        struct CountingInterceptor {
            before_count: Arc<AtomicUsize>,
            after_count: Arc<AtomicUsize>,
        }

        impl CallInterceptor for CountingInterceptor {
            async fn before<'a>(&'a self, _req: &'a mut ClientRequest) -> ClientResult<()> {
                self.before_count.fetch_add(1, Ordering::SeqCst);
                Ok(())
            }
            async fn after<'a>(&'a self, _resp: &'a ClientResponse) -> ClientResult<()> {
                self.after_count.fetch_add(1, Ordering::SeqCst);
                Ok(())
            }
        }

        let before = Arc::new(AtomicUsize::new(0));
        let after = Arc::new(AtomicUsize::new(0));
        let interceptor = CountingInterceptor {
            before_count: Arc::clone(&before),
            after_count: Arc::clone(&after),
        };

        let client = ClientBuilder::new("http://localhost:8080")
            .with_custom_transport(StreamingOkTransport)
            .with_interceptor(interceptor)
            .build()
            .expect("build");

        let result = client.stream_message(make_params()).await;
        assert!(result.is_ok(), "stream_message should succeed");
        assert_eq!(before.load(Ordering::SeqCst), 1, "before should be called");
        assert_eq!(
            after.load(Ordering::SeqCst),
            1,
            "after should be called for streaming"
        );
    }

    #[test]
    fn apply_config_does_not_set_modes_when_config_modes_empty() {
        let config = ClientConfig {
            accepted_output_modes: vec![],
            ..ClientConfig::default()
        };

        let mut params = make_params();
        // Both the request and the client config have empty mode lists.
        params.configuration = Some(SendMessageConfiguration {
            accepted_output_modes: vec![],
            task_push_notification_config: None,
            history_length: None,
            return_immediately: None,
        });
        apply_client_config(&mut params, &config);

        let cfg = params.configuration.unwrap();
        assert!(
            cfg.accepted_output_modes.is_empty(),
            "should not set modes when config modes are empty"
        );
    }
}