Skip to main content

bamboo_a2a/
client.rs

1use async_trait::async_trait;
2use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION, CONTENT_TYPE};
3use std::str::FromStr as _;
4use std::time::Duration;
5
6use super::error::{map_jsonrpc_error, A2AClientError, A2AClientResult};
7use super::jsonrpc::{methods, JsonRpcId, JsonRpcRequest, JsonRpcResponse};
8use super::sse::stream_response_from_sse;
9use super::sse::A2AStream;
10use super::types::{
11    AgentCard, CancelTaskRequest, GetTaskRequest, SendMessageRequest, SendMessageResponse, Task,
12};
13
14#[derive(Debug, Clone)]
15pub struct A2AClientConfig {
16    /// Stable Bamboo-side profile id, e.g. "remote-impl".
17    pub profile_id: String,
18    /// Agent Card discovery URL.
19    pub agent_card_url: String,
20    /// Optional RPC URL override. If absent, pick first JSONRPC interface from Agent Card.
21    pub rpc_url_override: Option<String>,
22    /// Authentication material.
23    pub auth: A2AAuth,
24    /// Optional tenant header/path param.
25    pub tenant: Option<String>,
26    /// Request timeout for non-streaming calls.
27    pub request_timeout: Duration,
28    /// Optional A2A-Version header. For v1.0 use "1.0".
29    pub protocol_version: String,
30    /// Optional required extensions to advertise via A2A-Extensions header.
31    pub extensions: Vec<String>,
32}
33
/// Authentication material attached to outgoing JSON-RPC requests.
///
/// `Debug` is implemented by hand (instead of derived) so that secrets — bearer
/// tokens and API-key values — are never written to logs or panic messages,
/// including when an enclosing config struct is debug-printed. The API-key
/// header *name* is not secret and is still shown.
#[derive(Clone)]
pub enum A2AAuth {
    /// No authentication header is sent.
    None,
    /// Bearer token without the `Bearer ` prefix; sent as `Authorization: Bearer <token>`.
    Bearer(String),
    /// A static API key sent verbatim as `<header>: <value>`.
    ApiKeyHeader { header: String, value: String },
}

impl std::fmt::Debug for A2AAuth {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Placeholder printed in place of any secret material.
        const REDACTED: &str = "<redacted>";
        match self {
            Self::None => f.write_str("None"),
            Self::Bearer(_) => f.debug_tuple("Bearer").field(&REDACTED).finish(),
            Self::ApiKeyHeader { header, .. } => f
                .debug_struct("ApiKeyHeader")
                .field("header", header)
                .field("value", &REDACTED)
                .finish(),
        }
    }
}
40
41#[async_trait]
42pub trait A2AClient: Send + Sync {
43    async fn fetch_agent_card(&self) -> A2AClientResult<AgentCard>;
44    async fn send_message(
45        &self,
46        request: SendMessageRequest,
47    ) -> A2AClientResult<SendMessageResponse>;
48    async fn send_streaming_message(
49        &self,
50        request: SendMessageRequest,
51    ) -> A2AClientResult<A2AStream>;
52    async fn get_task(&self, request: GetTaskRequest) -> A2AClientResult<Task>;
53    async fn cancel_task(&self, request: CancelTaskRequest) -> A2AClientResult<Task>;
54}
55
56pub struct A2AJsonRpcClient {
57    http: reqwest::Client,
58    config: A2AClientConfig,
59    resolved_rpc_url: tokio::sync::RwLock<Option<String>>,
60}
61
62impl A2AJsonRpcClient {
63    pub fn new(config: A2AClientConfig) -> A2AClientResult<Self> {
64        let http = reqwest::Client::builder()
65            .timeout(config.request_timeout)
66            .build()
67            .map_err(A2AClientError::Http)?;
68        Ok(Self {
69            http,
70            config,
71            resolved_rpc_url: tokio::sync::RwLock::new(None),
72        })
73    }
74
75    pub fn new_with_http_client(http: reqwest::Client, config: A2AClientConfig) -> Self {
76        Self {
77            http,
78            config,
79            resolved_rpc_url: tokio::sync::RwLock::new(None),
80        }
81    }
82
83    fn build_headers(&self, accept_streaming: bool) -> A2AClientResult<HeaderMap> {
84        let mut headers = HeaderMap::new();
85        headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
86        let accept = if accept_streaming {
87            "text/event-stream"
88        } else {
89            "application/json"
90        };
91        headers.insert(
92            "Accept",
93            HeaderValue::from_str(accept).map_err(|e| {
94                A2AClientError::InvalidStreamResponse(format!("Invalid accept header: {}", e))
95            })?,
96        );
97
98        match &self.config.auth {
99            A2AAuth::None => {}
100            A2AAuth::Bearer(token) => {
101                let value = format!("Bearer {}", token);
102                headers.insert(
103                    AUTHORIZATION,
104                    HeaderValue::from_str(&value).map_err(|e| {
105                        A2AClientError::InvalidStreamResponse(format!(
106                            "Invalid authorization header: {}",
107                            e
108                        ))
109                    })?,
110                );
111            }
112            A2AAuth::ApiKeyHeader { header, value } => {
113                let name = reqwest::header::HeaderName::from_str(header).map_err(|e| {
114                    A2AClientError::InvalidStreamResponse(format!(
115                        "Invalid API key header name: {}",
116                        e
117                    ))
118                })?;
119                headers.insert(
120                    name,
121                    HeaderValue::from_str(value).map_err(|e| {
122                        A2AClientError::InvalidStreamResponse(format!(
123                            "Invalid API key header value: {}",
124                            e
125                        ))
126                    })?,
127                );
128            }
129        }
130
131        headers.insert(
132            "A2A-Version",
133            HeaderValue::from_str(&self.config.protocol_version).map_err(|e| {
134                A2AClientError::InvalidStreamResponse(format!("Invalid A2A-Version header: {}", e))
135            })?,
136        );
137
138        if !self.config.extensions.is_empty() {
139            let extensions = self.config.extensions.join(",");
140            headers.insert(
141                "A2A-Extensions",
142                HeaderValue::from_str(&extensions).map_err(|e| {
143                    A2AClientError::InvalidStreamResponse(format!(
144                        "Invalid A2A-Extensions header: {}",
145                        e
146                    ))
147                })?,
148            );
149        }
150
151        Ok(headers)
152    }
153
154    async fn resolve_rpc_url(&self) -> A2AClientResult<String> {
155        // Check cache first
156        {
157            let cache = self.resolved_rpc_url.read().await;
158            if let Some(url) = cache.as_ref() {
159                return Ok(url.clone());
160            }
161        }
162
163        // Use override if configured
164        if let Some(override_url) = &self.config.rpc_url_override {
165            let mut cache = self.resolved_rpc_url.write().await;
166            cache.replace(override_url.clone());
167            return Ok(override_url.clone());
168        }
169
170        // Fetch Agent Card and find JSONRPC interface
171        let card = self.fetch_agent_card().await?;
172        let jsonrpc_interface = card
173            .supported_interfaces
174            .into_iter()
175            .find(|iface| iface.protocol_binding.eq_ignore_ascii_case("JSONRPC"))
176            .ok_or_else(|| {
177                A2AClientError::InvalidAgentCard("Agent Card has no JSONRPC interface".to_string())
178            })?;
179
180        let major = jsonrpc_interface
181            .protocol_version
182            .split('.')
183            .next()
184            .and_then(|s| s.parse::<u32>().ok())
185            .ok_or_else(|| {
186                A2AClientError::InvalidAgentCard(format!(
187                    "Invalid protocol version: {}",
188                    jsonrpc_interface.protocol_version
189                ))
190            })?;
191        if major != 1 {
192            return Err(A2AClientError::VersionNotSupported(format!(
193                "Protocol major version {} != 1",
194                major
195            )));
196        }
197
198        let mut cache = self.resolved_rpc_url.write().await;
199        cache.replace(jsonrpc_interface.url.clone());
200        Ok(jsonrpc_interface.url)
201    }
202
203    fn make_request_id(&self) -> JsonRpcId {
204        JsonRpcId::String(uuid::Uuid::new_v4().to_string())
205    }
206
207    async fn do_jsonrpc_call<Req, Resp>(
208        &self,
209        method: &'static str,
210        params: Req,
211    ) -> A2AClientResult<Resp>
212    where
213        Req: serde::Serialize + Send,
214        Resp: serde::de::DeserializeOwned,
215    {
216        let url = self.resolve_rpc_url().await?;
217        let headers = self.build_headers(false)?;
218        let request = JsonRpcRequest {
219            jsonrpc: super::jsonrpc::JSONRPC_VERSION,
220            id: self.make_request_id(),
221            method,
222            params: Some(params),
223        };
224
225        let body = serde_json::to_string(&request).map_err(A2AClientError::Json)?;
226        let response = self
227            .http
228            .post(&url)
229            .headers(headers)
230            .body(body)
231            .send()
232            .await
233            .map_err(A2AClientError::Http)?;
234
235        if !response.status().is_success() {
236            let status = response.status();
237            let text = response.text().await.unwrap_or_default();
238            return Err(A2AClientError::Sse(format!(
239                "HTTP error {}: {}",
240                status, text
241            )));
242        }
243
244        let body = response.bytes().await.map_err(A2AClientError::Http)?;
245        let envelope: JsonRpcResponse<Resp> =
246            serde_json::from_slice(&body).map_err(A2AClientError::Json)?;
247
248        if let Some(err) = envelope.error {
249            return Err(map_jsonrpc_error(err, None));
250        }
251
252        envelope.result.ok_or_else(|| {
253            A2AClientError::InvalidStreamResponse(
254                "missing result and error in JSON-RPC response".to_string(),
255            )
256        })
257    }
258}
259
260#[async_trait]
261impl A2AClient for A2AJsonRpcClient {
262    async fn fetch_agent_card(&self) -> A2AClientResult<AgentCard> {
263        let response = self
264            .http
265            .get(&self.config.agent_card_url)
266            .send()
267            .await
268            .map_err(A2AClientError::Http)?;
269
270        if !response.status().is_success() {
271            let status = response.status();
272            let text = response.text().await.unwrap_or_default();
273            return Err(A2AClientError::Sse(format!(
274                "HTTP error {} fetching agent card: {}",
275                status, text
276            )));
277        }
278
279        response.json().await.map_err(A2AClientError::Http)
280    }
281
282    async fn send_message(
283        &self,
284        request: SendMessageRequest,
285    ) -> A2AClientResult<SendMessageResponse> {
286        self.do_jsonrpc_call(methods::SEND_MESSAGE, request).await
287    }
288
289    async fn send_streaming_message(
290        &self,
291        request: SendMessageRequest,
292    ) -> A2AClientResult<A2AStream> {
293        let url = self.resolve_rpc_url().await?;
294        let headers = self.build_headers(true)?;
295        let jsonrpc_request = JsonRpcRequest {
296            jsonrpc: super::jsonrpc::JSONRPC_VERSION,
297            id: self.make_request_id(),
298            method: methods::SEND_STREAMING_MESSAGE,
299            params: Some(request),
300        };
301
302        let body = serde_json::to_string(&jsonrpc_request).map_err(A2AClientError::Json)?;
303        let response = self
304            .http
305            .post(&url)
306            .headers(headers)
307            .body(body)
308            .send()
309            .await
310            .map_err(A2AClientError::Http)?;
311
312        if !response.status().is_success() {
313            let status = response.status();
314            let text = response.text().await.unwrap_or_default();
315            return Err(A2AClientError::Sse(format!(
316                "HTTP error {}: {}",
317                status, text
318            )));
319        }
320
321        Ok(stream_response_from_sse(response))
322    }
323
324    async fn get_task(&self, request: GetTaskRequest) -> A2AClientResult<Task> {
325        self.do_jsonrpc_call(methods::GET_TASK, request).await
326    }
327
328    async fn cancel_task(&self, request: CancelTaskRequest) -> A2AClientResult<Task> {
329        self.do_jsonrpc_call(methods::CANCEL_TASK, request).await
330    }
331}