use crate::types::{
AgentCard, GetTaskParams, ListTasksParams, ListTasksResult, Message, PushNotificationConfig,
SendMessageParams, SendMessageResult, Task,
};
use serde::de::DeserializeOwned;
use serde::Serialize;
use serde_json::Value;
use std::sync::atomic::{AtomicU64, Ordering};
/// Errors returned by [`A2aClient`] operations.
#[derive(Debug, thiserror::Error)]
pub enum ClientError {
/// The HTTP request could not be completed, or `reqwest` failed to read or
/// decode a response body as JSON.
#[error("transport error: {0}")]
Transport(#[from] reqwest::Error),
/// The peer answered with a non-2xx HTTP status. `body` is a truncated copy
/// of the response body that has been passed through `redact_credentials`.
#[error("a2a peer returned HTTP {code}: {body}")]
Status { code: u16, body: String },
/// A `serde_json` conversion failed, e.g. the agent-card JSON did not match
/// [`AgentCard`].
#[error("serialization error: {0}")]
Serialize(#[from] serde_json::Error),
/// The peer returned a JSON-RPC `error` object. `code` is 0 and `message` is
/// `"(no message)"` when the peer omitted those members.
#[error("a2a peer returned error code {code}: {message}")]
Rpc { code: i32, message: String },
/// The JSON-RPC `result` could not be deserialized into the type the caller
/// expected; holds the serde error text.
#[error("a2a peer returned an unexpected result shape: {0}")]
BadResultShape(String),
/// The response lacked a member the client requires (for example `result`,
/// or `configId` when creating a push-notification config).
#[error("a2a peer returned malformed response: {0}")]
Malformed(String),
}
/// Credentials attached to every request made by an `A2aClient`.
///
/// The `Debug` output masks secret values, so a client configuration can be
/// logged or included in error reports without leaking tokens.
#[derive(Clone, Default)]
pub enum ClientAuth {
    /// Send no credentials (the default).
    #[default]
    None,
    /// Send the token as an `Authorization: Bearer <token>` header.
    Bearer(String),
    /// Send an arbitrary header, e.g. `X-API-Key: <value>`.
    Header { name: String, value: String },
}

impl std::fmt::Debug for ClientAuth {
    /// Formats like the derived impl, but with secret material replaced by
    /// `"[REDACTED]"`. The header *name* is not secret and is kept for diagnostics.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::None => f.write_str("None"),
            Self::Bearer(_) => f.debug_tuple("Bearer").field(&"[REDACTED]").finish(),
            Self::Header { name, .. } => f
                .debug_struct("Header")
                .field("name", name)
                .field("value", &"[REDACTED]")
                .finish(),
        }
    }
}
pub struct A2aClient {
http: reqwest::Client,
base_url: String,
auth: ClientAuth,
next_id: AtomicU64,
}
impl A2aClient {
pub fn new(base_url: impl Into<String>) -> Self {
let base = base_url.into();
let trimmed = base.trim_end_matches('/').to_string();
Self {
http: reqwest::Client::builder()
.timeout(std::time::Duration::from_secs(30))
.build()
.expect("default reqwest client must build"),
base_url: trimmed,
auth: ClientAuth::None,
next_id: AtomicU64::new(1),
}
}
pub fn with_http_client(mut self, http: reqwest::Client) -> Self {
self.http = http;
self
}
pub fn with_auth(mut self, auth: ClientAuth) -> Self {
self.auth = auth;
self
}
pub async fn agent_card(&self) -> Result<AgentCard, ClientError> {
let url = format!("{}/.well-known/agent-card.json", self.base_url);
let resp = self.apply_auth(self.http.get(url)).send().await?;
let body: Value = resp.json().await?;
Ok(serde_json::from_value(body)?)
}
pub async fn send_message(
&self,
message: Message,
blocking: bool,
) -> Result<SendMessageResult, ClientError> {
let params = SendMessageParams {
message,
configuration: Some(crate::types::MessageConfiguration {
blocking,
..Default::default()
}),
};
self.call("SendMessage", ¶ms).await
}
pub async fn get_task(&self, task_id: impl Into<String>) -> Result<Task, ClientError> {
let params = GetTaskParams {
id: task_id.into(),
history_length: None,
};
self.call("GetTask", ¶ms).await
}
pub async fn list_tasks(&self, filter: ListTasksParams) -> Result<ListTasksResult, ClientError> {
self.call("ListTasks", &filter).await
}
pub async fn cancel_task(&self, task_id: impl Into<String>) -> Result<Task, ClientError> {
#[derive(Serialize)]
struct P {
id: String,
}
self.call("CancelTask", &P { id: task_id.into() }).await
}
pub async fn set_push_config(
&self,
task_id: impl Into<String>,
config: PushNotificationConfig,
) -> Result<String, ClientError> {
#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
struct P {
task_id: String,
config: PushNotificationConfig,
}
let resp: Value = self
.call(
"CreateTaskPushNotificationConfig",
&P {
task_id: task_id.into(),
config,
},
)
.await?;
resp.get("configId")
.and_then(|v| v.as_str())
.map(|s| s.to_string())
.ok_or_else(|| ClientError::Malformed("missing configId in response".into()))
}
pub async fn call<P: Serialize, R: DeserializeOwned>(
&self,
method: &str,
params: &P,
) -> Result<R, ClientError> {
let id = self.next_id.fetch_add(1, Ordering::SeqCst);
let body = serde_json::json!({
"jsonrpc": "2.0",
"method": method,
"params": params,
"id": id,
});
let resp = self
.apply_auth(self.http.post(&self.base_url).json(&body))
.send()
.await?;
let status = resp.status();
if !status.is_success() {
let body = resp.text().await.unwrap_or_default();
let truncated: String = body.chars().take(1024).collect();
return Err(ClientError::Status {
code: status.as_u16(),
body: redact_credentials(&truncated),
});
}
let envelope: Value = resp.json().await?;
if let Some(err) = envelope.get("error") {
let code = err.get("code").and_then(|v| v.as_i64()).unwrap_or(0) as i32;
let message = err
.get("message")
.and_then(|v| v.as_str())
.unwrap_or("(no message)")
.to_string();
return Err(ClientError::Rpc { code, message });
}
let result = envelope
.get("result")
.ok_or_else(|| ClientError::Malformed("missing `result` field".into()))?
.clone();
serde_json::from_value(result).map_err(|e| ClientError::BadResultShape(e.to_string()))
}
fn apply_auth(&self, req: reqwest::RequestBuilder) -> reqwest::RequestBuilder {
match &self.auth {
ClientAuth::None => req,
ClientAuth::Bearer(token) => req.bearer_auth(token),
ClientAuth::Header { name, value } => req.header(name.as_str(), value.as_str()),
}
}
}
/// Header / JSON key names whose values [`redact_credentials`] treats as
/// credentials. Matched case-insensitively at the start of a line.
const SENSITIVE_KEYS: &[&str] = &[
    "authorization",
    "proxy-authorization",
    "x-api-key",
    "api-key",
];

/// How one line (terminator excluded) is rewritten by [`redact_credentials`].
enum LineRedaction {
    /// Not sensitive: copy verbatim.
    Keep,
    /// Keep everything through the `:` at this byte offset and redact the rest.
    AfterColon(usize),
    /// Replace the entire line with `[REDACTED]`.
    Whole,
}

/// Masks credential-looking lines in `input` (e.g. an HTTP error body) before
/// it is surfaced in an error message.
///
/// A line is sensitive when it starts (after leading whitespace, optionally
/// inside double quotes, optionally followed by spaces) with one of
/// [`SENSITIVE_KEYS`] and a `:`. In that case everything after the colon is
/// replaced with ` [REDACTED]`. A line that starts with `bearer ` or contains
/// ` bearer ` is also sensitive. It keeps only a label that ends in a colon
/// *before* the token; otherwise the whole line becomes `[REDACTED]`.
/// Line terminators (`\n` or `\r\n`) are preserved exactly.
fn redact_credentials(input: &str) -> String {
    let mut out = String::with_capacity(input.len());
    for line in input.split_inclusive('\n') {
        let content = line.trim_end_matches(|c: char| c == '\r' || c == '\n');
        let terminator = &line[content.len()..];
        match classify_line(content) {
            LineRedaction::Keep => out.push_str(content),
            LineRedaction::AfterColon(idx) => {
                out.push_str(&content[..=idx]);
                out.push_str(" [REDACTED]");
            }
            LineRedaction::Whole => out.push_str("[REDACTED]"),
        }
        out.push_str(terminator);
    }
    out
}

/// Decides whether `line` carries a credential and how much of it may be kept.
fn classify_line(line: &str) -> LineRedaction {
    // ASCII lowercasing keeps byte offsets identical to `line`.
    let lower = line.to_ascii_lowercase();
    let indent = lower.len() - lower.trim_start().len();
    let head = &lower[indent..];

    if let Some(colon) = sensitive_key_colon(head) {
        return LineRedaction::AfterColon(indent + colon);
    }

    let bearer_at = if head.starts_with("bearer ") {
        Some(indent)
    } else {
        head.find(" bearer ").map(|i| indent + i)
    };
    match (bearer_at, line.find(':')) {
        (None, _) => LineRedaction::Keep,
        // A colon before the token is a header-style separator: keep the label.
        (Some(token), Some(colon)) if colon < token => LineRedaction::AfterColon(colon),
        // A colon after the token (or none) would leave the token in the kept prefix.
        (Some(_), _) => LineRedaction::Whole,
    }
}

/// If `head` (lowercased, leading whitespace removed) starts with a sensitive
/// key — optionally wrapped in double quotes and followed by optional
/// whitespace — returns the byte offset of the `:` that follows it.
fn sensitive_key_colon(head: &str) -> Option<usize> {
    let unquoted = head.strip_prefix('"').unwrap_or(head);
    SENSITIVE_KEYS.iter().find_map(|key| {
        let rest = unquoted.strip_prefix(*key)?;
        let rest = rest.strip_prefix('"').unwrap_or(rest);
        let after_ws = rest.trim_start();
        after_ws
            .starts_with(':')
            .then_some(head.len() - after_ws.len())
    })
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{MessageRole, Part, TextPart};
    use std::collections::HashMap;

    /// `new` drops the trailing `/` so endpoint paths can be appended directly.
    #[test]
    fn client_trims_trailing_slash() {
        let client = A2aClient::new("http://example.test/");
        assert_eq!(client.base_url, "http://example.test");
    }

    /// A freshly constructed client carries no credentials.
    #[test]
    fn client_auth_default_is_none() {
        let client = A2aClient::new("http://x");
        assert!(matches!(client.auth, ClientAuth::None));
    }

    /// Credential headers are masked while ordinary headers survive.
    #[test]
    fn redact_credentials_strips_authorization_lines() {
        let raw = concat!(
            "HTTP/1.1 401 Unauthorized\r\n",
            "Authorization: Bearer SECRET123\r\n",
            "X-API-Key: KEY456\r\n",
            "Content-Type: text/plain\r\n",
            "\r\n",
            "request rejected",
        );
        let redacted = redact_credentials(raw);
        assert!(!redacted.contains("SECRET123"), "bearer not redacted: {}", redacted);
        assert!(!redacted.contains("KEY456"), "api key not redacted: {}", redacted);
        assert!(redacted.contains("[REDACTED]"));
        assert!(redacted.contains("Content-Type: text/plain"));
    }

    /// A plain-text user `Message` can be assembled from these fields.
    #[test]
    fn message_can_be_constructed_for_call() {
        let greeting = Part::Text(TextPart {
            text: "hi".into(),
            metadata: HashMap::new(),
        });
        let _msg = Message {
            message_id: "m".into(),
            role: MessageRole::User,
            parts: vec![greeting],
            task_id: None,
            context_id: None,
            metadata: HashMap::new(),
        };
    }
}