1use async_trait::async_trait;
2use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION, CONTENT_TYPE};
3use std::str::FromStr as _;
4use std::time::Duration;
5
6use super::error::{map_jsonrpc_error, A2AClientError, A2AClientResult};
7use super::jsonrpc::{methods, JsonRpcId, JsonRpcRequest, JsonRpcResponse};
8use super::sse::stream_response_from_sse;
9use super::sse::A2AStream;
10use super::types::{
11 AgentCard, CancelTaskRequest, GetTaskRequest, SendMessageRequest, SendMessageResponse, Task,
12};
13
14#[derive(Debug, Clone)]
15pub struct A2AClientConfig {
16 pub profile_id: String,
18 pub agent_card_url: String,
20 pub rpc_url_override: Option<String>,
22 pub auth: A2AAuth,
24 pub tenant: Option<String>,
26 pub request_timeout: Duration,
28 pub protocol_version: String,
30 pub extensions: Vec<String>,
32}
33
/// Credentials added to the headers of JSON-RPC requests.
///
/// `Debug` is implemented by hand so secrets (bearer tokens, API-key values) are
/// redacted; a derived `Debug` would print them verbatim, including via the derived
/// `Debug` of any struct that embeds this type.
#[derive(Clone)]
pub enum A2AAuth {
    /// No authentication header is sent.
    None,
    /// Sent as `Authorization: Bearer <token>`.
    Bearer(String),
    /// Sent as the custom header `header: value`.
    ApiKeyHeader { header: String, value: String },
}

impl std::fmt::Debug for A2AAuth {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Placeholder rendered in place of any secret material.
        const REDACTED: &str = "<redacted>";
        match self {
            A2AAuth::None => f.write_str("None"),
            A2AAuth::Bearer(_) => f.debug_tuple("Bearer").field(&REDACTED).finish(),
            // The header *name* is not secret and is useful when debugging configuration.
            A2AAuth::ApiKeyHeader { header, .. } => f
                .debug_struct("ApiKeyHeader")
                .field("header", header)
                .field("value", &REDACTED)
                .finish(),
        }
    }
}
40
41#[async_trait]
42pub trait A2AClient: Send + Sync {
43 async fn fetch_agent_card(&self) -> A2AClientResult<AgentCard>;
44 async fn send_message(
45 &self,
46 request: SendMessageRequest,
47 ) -> A2AClientResult<SendMessageResponse>;
48 async fn send_streaming_message(
49 &self,
50 request: SendMessageRequest,
51 ) -> A2AClientResult<A2AStream>;
52 async fn get_task(&self, request: GetTaskRequest) -> A2AClientResult<Task>;
53 async fn cancel_task(&self, request: CancelTaskRequest) -> A2AClientResult<Task>;
54}
55
56pub struct A2AJsonRpcClient {
57 http: reqwest::Client,
58 config: A2AClientConfig,
59 resolved_rpc_url: tokio::sync::RwLock<Option<String>>,
60}
61
62impl A2AJsonRpcClient {
63 pub fn new(config: A2AClientConfig) -> A2AClientResult<Self> {
64 let http = reqwest::Client::builder()
65 .timeout(config.request_timeout)
66 .build()
67 .map_err(A2AClientError::Http)?;
68 Ok(Self {
69 http,
70 config,
71 resolved_rpc_url: tokio::sync::RwLock::new(None),
72 })
73 }
74
75 pub fn new_with_http_client(http: reqwest::Client, config: A2AClientConfig) -> Self {
76 Self {
77 http,
78 config,
79 resolved_rpc_url: tokio::sync::RwLock::new(None),
80 }
81 }
82
83 fn build_headers(&self, accept_streaming: bool) -> A2AClientResult<HeaderMap> {
84 let mut headers = HeaderMap::new();
85 headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
86 let accept = if accept_streaming {
87 "text/event-stream"
88 } else {
89 "application/json"
90 };
91 headers.insert(
92 "Accept",
93 HeaderValue::from_str(accept).map_err(|e| {
94 A2AClientError::InvalidStreamResponse(format!("Invalid accept header: {}", e))
95 })?,
96 );
97
98 match &self.config.auth {
99 A2AAuth::None => {}
100 A2AAuth::Bearer(token) => {
101 let value = format!("Bearer {}", token);
102 headers.insert(
103 AUTHORIZATION,
104 HeaderValue::from_str(&value).map_err(|e| {
105 A2AClientError::InvalidStreamResponse(format!(
106 "Invalid authorization header: {}",
107 e
108 ))
109 })?,
110 );
111 }
112 A2AAuth::ApiKeyHeader { header, value } => {
113 let name = reqwest::header::HeaderName::from_str(header).map_err(|e| {
114 A2AClientError::InvalidStreamResponse(format!(
115 "Invalid API key header name: {}",
116 e
117 ))
118 })?;
119 headers.insert(
120 name,
121 HeaderValue::from_str(value).map_err(|e| {
122 A2AClientError::InvalidStreamResponse(format!(
123 "Invalid API key header value: {}",
124 e
125 ))
126 })?,
127 );
128 }
129 }
130
131 headers.insert(
132 "A2A-Version",
133 HeaderValue::from_str(&self.config.protocol_version).map_err(|e| {
134 A2AClientError::InvalidStreamResponse(format!("Invalid A2A-Version header: {}", e))
135 })?,
136 );
137
138 if !self.config.extensions.is_empty() {
139 let extensions = self.config.extensions.join(",");
140 headers.insert(
141 "A2A-Extensions",
142 HeaderValue::from_str(&extensions).map_err(|e| {
143 A2AClientError::InvalidStreamResponse(format!(
144 "Invalid A2A-Extensions header: {}",
145 e
146 ))
147 })?,
148 );
149 }
150
151 Ok(headers)
152 }
153
154 async fn resolve_rpc_url(&self) -> A2AClientResult<String> {
155 {
157 let cache = self.resolved_rpc_url.read().await;
158 if let Some(url) = cache.as_ref() {
159 return Ok(url.clone());
160 }
161 }
162
163 if let Some(override_url) = &self.config.rpc_url_override {
165 let mut cache = self.resolved_rpc_url.write().await;
166 cache.replace(override_url.clone());
167 return Ok(override_url.clone());
168 }
169
170 let card = self.fetch_agent_card().await?;
172 let jsonrpc_interface = card
173 .supported_interfaces
174 .into_iter()
175 .find(|iface| iface.protocol_binding.eq_ignore_ascii_case("JSONRPC"))
176 .ok_or_else(|| {
177 A2AClientError::InvalidAgentCard("Agent Card has no JSONRPC interface".to_string())
178 })?;
179
180 let major = jsonrpc_interface
181 .protocol_version
182 .split('.')
183 .next()
184 .and_then(|s| s.parse::<u32>().ok())
185 .ok_or_else(|| {
186 A2AClientError::InvalidAgentCard(format!(
187 "Invalid protocol version: {}",
188 jsonrpc_interface.protocol_version
189 ))
190 })?;
191 if major != 1 {
192 return Err(A2AClientError::VersionNotSupported(format!(
193 "Protocol major version {} != 1",
194 major
195 )));
196 }
197
198 let mut cache = self.resolved_rpc_url.write().await;
199 cache.replace(jsonrpc_interface.url.clone());
200 Ok(jsonrpc_interface.url)
201 }
202
203 fn make_request_id(&self) -> JsonRpcId {
204 JsonRpcId::String(uuid::Uuid::new_v4().to_string())
205 }
206
207 async fn do_jsonrpc_call<Req, Resp>(
208 &self,
209 method: &'static str,
210 params: Req,
211 ) -> A2AClientResult<Resp>
212 where
213 Req: serde::Serialize + Send,
214 Resp: serde::de::DeserializeOwned,
215 {
216 let url = self.resolve_rpc_url().await?;
217 let headers = self.build_headers(false)?;
218 let request = JsonRpcRequest {
219 jsonrpc: super::jsonrpc::JSONRPC_VERSION,
220 id: self.make_request_id(),
221 method,
222 params: Some(params),
223 };
224
225 let body = serde_json::to_string(&request).map_err(A2AClientError::Json)?;
226 let response = self
227 .http
228 .post(&url)
229 .headers(headers)
230 .body(body)
231 .send()
232 .await
233 .map_err(A2AClientError::Http)?;
234
235 if !response.status().is_success() {
236 let status = response.status();
237 let text = response.text().await.unwrap_or_default();
238 return Err(A2AClientError::Sse(format!(
239 "HTTP error {}: {}",
240 status, text
241 )));
242 }
243
244 let body = response.bytes().await.map_err(A2AClientError::Http)?;
245 let envelope: JsonRpcResponse<Resp> =
246 serde_json::from_slice(&body).map_err(A2AClientError::Json)?;
247
248 if let Some(err) = envelope.error {
249 return Err(map_jsonrpc_error(err, None));
250 }
251
252 envelope.result.ok_or_else(|| {
253 A2AClientError::InvalidStreamResponse(
254 "missing result and error in JSON-RPC response".to_string(),
255 )
256 })
257 }
258}
259
260#[async_trait]
261impl A2AClient for A2AJsonRpcClient {
262 async fn fetch_agent_card(&self) -> A2AClientResult<AgentCard> {
263 let response = self
264 .http
265 .get(&self.config.agent_card_url)
266 .send()
267 .await
268 .map_err(A2AClientError::Http)?;
269
270 if !response.status().is_success() {
271 let status = response.status();
272 let text = response.text().await.unwrap_or_default();
273 return Err(A2AClientError::Sse(format!(
274 "HTTP error {} fetching agent card: {}",
275 status, text
276 )));
277 }
278
279 response.json().await.map_err(A2AClientError::Http)
280 }
281
282 async fn send_message(
283 &self,
284 request: SendMessageRequest,
285 ) -> A2AClientResult<SendMessageResponse> {
286 self.do_jsonrpc_call(methods::SEND_MESSAGE, request).await
287 }
288
289 async fn send_streaming_message(
290 &self,
291 request: SendMessageRequest,
292 ) -> A2AClientResult<A2AStream> {
293 let url = self.resolve_rpc_url().await?;
294 let headers = self.build_headers(true)?;
295 let jsonrpc_request = JsonRpcRequest {
296 jsonrpc: super::jsonrpc::JSONRPC_VERSION,
297 id: self.make_request_id(),
298 method: methods::SEND_STREAMING_MESSAGE,
299 params: Some(request),
300 };
301
302 let body = serde_json::to_string(&jsonrpc_request).map_err(A2AClientError::Json)?;
303 let response = self
304 .http
305 .post(&url)
306 .headers(headers)
307 .body(body)
308 .send()
309 .await
310 .map_err(A2AClientError::Http)?;
311
312 if !response.status().is_success() {
313 let status = response.status();
314 let text = response.text().await.unwrap_or_default();
315 return Err(A2AClientError::Sse(format!(
316 "HTTP error {}: {}",
317 status, text
318 )));
319 }
320
321 Ok(stream_response_from_sse(response))
322 }
323
324 async fn get_task(&self, request: GetTaskRequest) -> A2AClientResult<Task> {
325 self.do_jsonrpc_call(methods::GET_TASK, request).await
326 }
327
328 async fn cancel_task(&self, request: CancelTaskRequest) -> A2AClientResult<Task> {
329 self.do_jsonrpc_call(methods::CANCEL_TASK, request).await
330 }
331}