1use std::collections::HashMap;
4
5use serde::{Deserialize, Serialize, de::Deserializer, ser::Serializer};
6use serde_json::Value;
7
/// Type URL of a `google.rpc.ErrorInfo` detail (`reason` / `domain` / `metadata`).
pub const ERROR_INFO_TYPE: &str = "type.googleapis.com/google.rpc.ErrorInfo";
/// Type URL of a `google.rpc.BadRequest` detail carrying `fieldViolations`.
pub const BAD_REQUEST_TYPE: &str = "type.googleapis.com/google.rpc.BadRequest";
/// Type URL of a free-form `google.protobuf.Struct` detail; also the type assumed
/// for an incoming detail that carries no `"@type"` key.
pub const STRUCT_TYPE: &str = "type.googleapis.com/google.protobuf.Struct";
/// `domain` value used for A2A protocol-level `ErrorInfo` details.
pub const PROTOCOL_DOMAIN: &str = "a2a-protocol.org";
12
13#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
15pub struct FieldViolation {
16 pub field: String,
17 pub description: String,
18}
19
20#[derive(Debug, Clone, PartialEq)]
26pub struct TypedDetail {
27 pub type_url: String,
28 pub value: HashMap<String, Value>,
29}
30
31impl TypedDetail {
32 pub fn new(type_url: impl Into<String>, value: HashMap<String, Value>) -> Self {
33 Self {
34 type_url: type_url.into(),
35 value,
36 }
37 }
38
39 pub fn error_info(
41 reason: impl Into<String>,
42 domain: impl Into<String>,
43 metadata: Option<HashMap<String, String>>,
44 ) -> Self {
45 let mut value = HashMap::new();
46 value.insert("reason".to_string(), Value::String(reason.into()));
47 value.insert("domain".to_string(), Value::String(domain.into()));
48 if let Some(meta) = metadata {
49 let meta_obj: serde_json::Map<String, Value> = meta
50 .into_iter()
51 .map(|(k, v)| (k, Value::String(v)))
52 .collect();
53 value.insert("metadata".to_string(), Value::Object(meta_obj));
54 }
55 Self {
56 type_url: ERROR_INFO_TYPE.to_string(),
57 value,
58 }
59 }
60
61 pub fn bad_request(field_violations: Vec<FieldViolation>) -> Self {
63 let violations: Vec<Value> = field_violations
64 .into_iter()
65 .map(|fv| serde_json::to_value(fv).unwrap_or_default())
66 .collect();
67 let mut value = HashMap::new();
68 value.insert("fieldViolations".to_string(), Value::Array(violations));
69 Self {
70 type_url: BAD_REQUEST_TYPE.to_string(),
71 value,
72 }
73 }
74
75 pub fn from_struct(fields: HashMap<String, Value>) -> Self {
77 Self {
78 type_url: STRUCT_TYPE.to_string(),
79 value: fields,
80 }
81 }
82}
83
84impl Serialize for TypedDetail {
85 fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
86 use serde::ser::SerializeMap;
87 let mut map = serializer.serialize_map(Some(self.value.len() + 1))?;
88 map.serialize_entry("@type", &self.type_url)?;
89 for (k, v) in &self.value {
90 map.serialize_entry(k, v)?;
91 }
92 map.end()
93 }
94}
95
96impl<'de> Deserialize<'de> for TypedDetail {
97 fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
98 let mut map: HashMap<String, Value> = HashMap::deserialize(deserializer)?;
99 let type_url = map
100 .remove("@type")
101 .and_then(|v| v.as_str().map(String::from))
102 .unwrap_or_else(|| STRUCT_TYPE.to_string());
103 Ok(Self {
104 type_url,
105 value: map,
106 })
107 }
108}
109
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for building a `FieldViolation` from string slices.
    fn violation(field: &str, description: &str) -> FieldViolation {
        FieldViolation {
            field: field.to_owned(),
            description: description.to_owned(),
        }
    }

    #[test]
    fn test_serialize_typed_detail() {
        let detail = TypedDetail::error_info("TASK_NOT_FOUND", PROTOCOL_DOMAIN, None);
        let encoded = serde_json::to_value(&detail).unwrap();
        for (key, expected) in [
            ("@type", ERROR_INFO_TYPE),
            ("reason", "TASK_NOT_FOUND"),
            ("domain", PROTOCOL_DOMAIN),
        ] {
            assert_eq!(encoded[key], expected, "mismatch at key {}", key);
        }
    }

    #[test]
    fn test_deserialize_typed_detail() {
        let input = serde_json::json!({
            "@type": ERROR_INFO_TYPE,
            "reason": "TASK_NOT_FOUND",
            "domain": PROTOCOL_DOMAIN,
            "metadata": { "taskId": "t1" }
        });
        let parsed: TypedDetail = serde_json::from_value(input).unwrap();
        assert_eq!(parsed.type_url, ERROR_INFO_TYPE);

        let fields = &parsed.value;
        assert_eq!(fields["reason"], "TASK_NOT_FOUND");
        assert_eq!(fields["domain"], PROTOCOL_DOMAIN);
        assert_eq!(fields["metadata"]["taskId"], "t1");
    }

    #[test]
    fn test_deserialize_without_type_defaults_to_struct() {
        let parsed: TypedDetail =
            serde_json::from_value(serde_json::json!({ "resource": "task" })).unwrap();
        assert_eq!(parsed.type_url, STRUCT_TYPE);
        assert_eq!(parsed.value["resource"], "task");
    }

    #[test]
    fn test_round_trip() {
        let metadata: HashMap<String, String> =
            std::iter::once(("taskId".to_owned(), "t1".to_owned())).collect();
        let original = TypedDetail::error_info("TASK_NOT_FOUND", PROTOCOL_DOMAIN, Some(metadata));
        let restored: TypedDetail =
            serde_json::from_value(serde_json::to_value(&original).unwrap()).unwrap();
        assert_eq!(restored, original);
    }

    #[test]
    fn test_bad_request() {
        let detail = TypedDetail::bad_request(vec![
            violation("message.parts", "At least one part is required"),
            violation("message.role", "Role must be 'user' or 'agent'"),
        ]);
        assert_eq!(detail.type_url, BAD_REQUEST_TYPE);

        let encoded = serde_json::to_value(&detail).unwrap();
        assert_eq!(encoded["@type"], BAD_REQUEST_TYPE);

        let entries = encoded["fieldViolations"].as_array().unwrap();
        assert_eq!(entries.len(), 2);
        let (first, second) = (&entries[0], &entries[1]);
        assert_eq!(first["field"], "message.parts");
        assert_eq!(first["description"], "At least one part is required");
        assert_eq!(second["field"], "message.role");
    }

    #[test]
    fn test_bad_request_round_trip() {
        let original = TypedDetail::bad_request(vec![violation("task.id", "Must not be empty")]);
        let restored: TypedDetail =
            serde_json::from_value(serde_json::to_value(&original).unwrap()).unwrap();
        assert_eq!(restored, original);
    }

    #[test]
    fn test_from_struct() {
        let fields = HashMap::from([("key".to_string(), Value::from("val"))]);
        let detail = TypedDetail::from_struct(fields.clone());
        assert_eq!(detail.type_url, STRUCT_TYPE);
        assert_eq!(detail.value, fields);
    }
}